From d66be2c1d3e00173332b4190776743362c973cdb Mon Sep 17 00:00:00 2001 From: LandonTClipp Date: Sun, 19 Nov 2023 20:29:51 -0600 Subject: [PATCH] Add more code to plumb through all the values needed by moq registry methods In this commit, we gather all the template data needed by the moq logic to generate its template. This is untested as of yet. TODO: need to start testing this works by calling upon `moq` in `.mockery.yaml`. --- pkg/interface.go | 3 ++ pkg/outputter.go | 6 ++- pkg/parse.go | 22 ++++----- pkg/registry/method_scope.go | 2 +- pkg/registry/package.go | 5 +- pkg/registry/registry.go | 11 ++--- pkg/registry/registry_test.go | 10 ---- pkg/registry/var.go | 2 +- pkg/template_generator.go | 93 ++++++++++++++++++++++++++++++----- 9 files changed, 108 insertions(+), 46 deletions(-) diff --git a/pkg/interface.go b/pkg/interface.go index 64168107..d3b96986 100644 --- a/pkg/interface.go +++ b/pkg/interface.go @@ -3,6 +3,8 @@ package pkg import ( "go/ast" "go/types" + + "golang.org/x/tools/go/packages" ) // Interface type represents the target type that we will generate a mock for. @@ -14,6 +16,7 @@ type Interface struct { QualifiedName string // Path to the package of the target type. FileName string File *ast.File + PackagesPackage *packages.Package Pkg TypesPackage NamedType *types.Named IsFunction bool // If true, this instance represents a function, otherwise it's an interface. diff --git a/pkg/outputter.go b/pkg/outputter.go index 82cf76d1..4f8cd99b 100644 --- a/pkg/outputter.go +++ b/pkg/outputter.go @@ -333,7 +333,11 @@ func (o *Outputter) Generate(ctx context.Context, iface *Interface) error { config := TemplateGeneratorConfig{ Style: interfaceConfig.Style, } - generator := NewTemplateGenerator(config) + generator, err := NewTemplateGenerator(iface.PackagesPackage, config) + if err != nil { + return fmt.Errorf("creating template generator: %w", err) + } + fmt.Printf("generator: %v\n", generator) } diff --git a/pkg/parse.go b/pkg/parse.go index c3836feb..e647c9ce 100644 --- a/pkg/parse.go +++ b/pkg/parse.go @@ -190,7 +190,7 @@ func (p *Parser) Find(name string) (*Interface, error) { for _, entry := range p.entries { for _, iface := range entry.interfaces { if iface == name { - list := p.packageInterfaces(entry.pkg.Types, entry.fileName, []string{name}, nil) + list := p.packageInterfaces(entry, []string{name}, nil) if len(list) > 0 { return list[0], nil } @@ -204,7 +204,7 @@ func (p *Parser) Interfaces() []*Interface { ifaces := make(sortableIFaceList, 0) for _, entry := range p.entries { declaredIfaces := entry.interfaces - ifaces = p.packageInterfaces(entry.pkg.Types, entry.fileName, declaredIfaces, ifaces) + ifaces = p.packageInterfaces(entry, declaredIfaces, ifaces) } sort.Sort(ifaces) @@ -212,11 +212,10 @@ func (p *Parser) Interfaces() []*Interface { } func (p *Parser) packageInterfaces( - pkg *types.Package, - fileName string, + entry *parserEntry, declaredInterfaces []string, ifaces []*Interface) []*Interface { - scope := pkg.Scope() + scope := entry.pkg.Types.Scope() for _, name := range declaredInterfaces { obj := scope.Lookup(name) if obj == nil { @@ -235,11 +234,12 @@ func (p *Parser) packageInterfaces( } elem := &Interface{ - Name: name, - Pkg: pkg, - QualifiedName: pkg.Path(), - FileName: fileName, - NamedType: typ, + Name: name, + PackagesPackage: entry.pkg, + Pkg: entry.pkg.Types, + QualifiedName: entry.pkg.Types.Path(), + FileName: entry.fileName, + NamedType: typ, } iface, ok := typ.Underlying().(*types.Interface) @@ -266,8 +266,6 @@ type TypesPackage interface { Path() string } - - type sortableIFaceList []*Interface func (s sortableIFaceList) Len() int { diff --git a/pkg/registry/method_scope.go b/pkg/registry/method_scope.go index 65b56162..8d301342 100644 --- a/pkg/registry/method_scope.go +++ b/pkg/registry/method_scope.go @@ -80,7 +80,7 @@ func (m MethodScope) populateImports(t types.Type, imports map[string]*Package) switch t := t.(type) { case *types.Named: if pkg := t.Obj().Pkg(); pkg != nil { - imports[stripVendorPath(pkg.Path())] = m.registry.AddImport(pkg) + imports[pkg.Path()] = m.registry.AddImport(pkg) } // The imports of a Type with a TypeList must be added to the imports list // For example: Foo[otherpackage.Bar] , must have otherpackage imported diff --git a/pkg/registry/package.go b/pkg/registry/package.go index d4945ea6..90192286 100644 --- a/pkg/registry/package.go +++ b/pkg/registry/package.go @@ -1,7 +1,6 @@ package registry import ( - "path" "strings" ) @@ -42,7 +41,7 @@ func (p *Package) Path() string { return "" } - return stripVendorPath(p.pkg.Path()) + return p.pkg.Path() } var replacer = strings.NewReplacer( @@ -72,8 +71,6 @@ func (p Package) uniqueName(lvl int) string { return name } - - func min(a, b int) int { if a < b { return a diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 75a70c04..3cc978a7 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -8,7 +8,6 @@ import ( "sort" "strings" - "github.com/chigopher/pathlib" "golang.org/x/tools/go/packages" ) @@ -18,19 +17,19 @@ import ( // qualifiers. type Registry struct { srcPkgName string + srcPkgPath string srcPkgTypes *types.Package - outputPath *pathlib.Path aliases map[string]string imports map[string]*Package } // New loads the source package info and returns a new instance of // Registry. -func New(srcPkg *packages.Package, outputPath *pathlib.Path) (*Registry, error) { +func New(srcPkg *packages.Package) (*Registry, error) { return &Registry{ srcPkgName: srcPkg.Name, + srcPkgPath: srcPkg.PkgPath, srcPkgTypes: srcPkg.Types, - outputPath: outputPath, aliases: parseImportsAliases(srcPkg.Syntax), imports: make(map[string]*Package), }, nil @@ -79,8 +78,8 @@ func (r *Registry) MethodScope() *MethodScope { // suitable alias if there are any conflicts with previously imported // packages. func (r *Registry) AddImport(pkg *types.Package) *Package { - path := stripVendorPath(pkg.Path()) - if pathlib.NewPath(path).Equals(r.outputPath) { + path := pkg.Path() + if pkg.Path() == r.srcPkgPath { return nil } diff --git a/pkg/registry/registry_test.go b/pkg/registry/registry_test.go index 48f91328..b2a276fb 100644 --- a/pkg/registry/registry_test.go +++ b/pkg/registry/registry_test.go @@ -1,11 +1 @@ package registry - -import ( - "testing" -) - -func BenchmarkNew(b *testing.B) { - for i := 0; i < b.N; i++ { - New("../../pkg/moq/testpackages/example", "") - } -} diff --git a/pkg/registry/var.go b/pkg/registry/var.go index 081a17c7..de8da2dc 100644 --- a/pkg/registry/var.go +++ b/pkg/registry/var.go @@ -30,7 +30,7 @@ func (v Var) TypeString() string { // packageQualifier is a types.Qualifier. func (v Var) packageQualifier(pkg *types.Package) string { - path := stripVendorPath(pkg.Path()) + path := pkg.Path() if v.moqPkgPath != "" && v.moqPkgPath == path { return "" } diff --git a/pkg/template_generator.go b/pkg/template_generator.go index 27da8568..b59c0f8b 100644 --- a/pkg/template_generator.go +++ b/pkg/template_generator.go @@ -1,39 +1,110 @@ package pkg import ( + "bytes" + "context" + "fmt" + "go/types" + + "github.com/chigopher/pathlib" + "github.com/rs/zerolog" "github.com/vektra/mockery/v2/pkg/config" + "github.com/vektra/mockery/v2/pkg/registry" "github.com/vektra/mockery/v2/pkg/template" + "golang.org/x/tools/go/packages" ) type TemplateGeneratorConfig struct { Style string } type TemplateGenerator struct { - config TemplateGeneratorConfig + config TemplateGeneratorConfig + registry *registry.Registry } -func NewTemplateGenerator(config TemplateGeneratorConfig) *TemplateGenerator { - return &TemplateGenerator{ - config: config, +func NewTemplateGenerator(srcPkg *packages.Package, config TemplateGeneratorConfig) (*TemplateGenerator, error) { + reg, err := registry.New(srcPkg) + if err != nil { + return nil, fmt.Errorf("creating new registry: %w", err) } + + return &TemplateGenerator{ + config: config, + registry: reg, + }, nil } -func (g *TemplateGenerator) Generate(iface *Interface, ifaceConfig *config.Config) error { - templ, err := template.New(g.config.Style) - if err != nil { - return err - } +func (g *TemplateGenerator) Generate(ctx context.Context, iface *Interface, ifaceConfig *config.Config) error { + log := zerolog.Ctx(ctx) + log.Info().Msg("generating mock for interface") + imports := Imports{} for _, method := range iface.Methods() { method.populateImports(imports) } - // TODO: Work on getting these imports into the template + methods := make([]template.MethodData, iface.ActualInterface.NumMethods()) + + for i := 0; i < iface.ActualInterface.NumMethods(); i++ { + method := iface.ActualInterface.Method(i) + methodScope := g.registry.MethodScope() + + signature := method.Type().(*types.Signature) + params := make([]template.ParamData, signature.Params().Len()) + for j := 0; j < signature.Params().Len(); j++ { + param := signature.Params().At(j) + params[j] = template.ParamData{ + Var: methodScope.AddVar(param, ""), + Variadic: signature.Variadic() && j == signature.Params().Len()-1, + } + } + + returns := make([]template.ParamData, signature.Results().Len()) + for j := 0; j < signature.Results().Len(); j++ { + param := signature.Results().At(j) + returns[j] = template.ParamData{ + Var: methodScope.AddVar(param, "Out"), + Variadic: false, + } + } + methods[i] = template.MethodData{ + Name: method.Name(), + Params: params, + Returns: returns, + } + + } + + // For now, mockery only supports one mock per file, which is why we're creating + // a single-element list. moq seems to have supported multiple mocks per file. + mockData := []template.MockData{ + { + InterfaceName: iface.Name, + MockName: ifaceConfig.MockName, + Methods: methods, + }, + } data := template.Data{ PkgName: ifaceConfig.Outpkg, SrcPkgQualifier: iface.Pkg.Name() + ".", - Imports: + Imports: g.registry.Imports(), + Mocks: mockData, } + templ, err := template.New(g.config.Style) + if err != nil { + return fmt.Errorf("creating new template: %w", err) + } + + var buf bytes.Buffer + if err := templ.Execute(&buf, data); err != nil { + return fmt.Errorf("executing template: %w", err) + } + + outPath := pathlib.NewPath(ifaceConfig.Dir).Join(ifaceConfig.FileName) + if err := outPath.WriteFile(buf.Bytes()); err != nil { + log.Error().Err(err).Msg("couldn't write to output file") + return fmt.Errorf("writing to output file: %w", err) + } return nil }