diff --git a/pkg/interface.go b/pkg/interface.go index 64168107..d3b96986 100644 --- a/pkg/interface.go +++ b/pkg/interface.go @@ -3,6 +3,8 @@ package pkg import ( "go/ast" "go/types" + + "golang.org/x/tools/go/packages" ) // Interface type represents the target type that we will generate a mock for. @@ -14,6 +16,7 @@ type Interface struct { QualifiedName string // Path to the package of the target type. FileName string File *ast.File + PackagesPackage *packages.Package Pkg TypesPackage NamedType *types.Named IsFunction bool // If true, this instance represents a function, otherwise it's an interface. diff --git a/pkg/outputter.go b/pkg/outputter.go index 82cf76d1..4f8cd99b 100644 --- a/pkg/outputter.go +++ b/pkg/outputter.go @@ -333,7 +333,11 @@ func (o *Outputter) Generate(ctx context.Context, iface *Interface) error { config := TemplateGeneratorConfig{ Style: interfaceConfig.Style, } - generator := NewTemplateGenerator(config) + generator, err := NewTemplateGenerator(iface.PackagesPackage, config) + if err != nil { + return fmt.Errorf("creating template generator: %w", err) + } + fmt.Printf("generator: %v\n", generator) } diff --git a/pkg/parse.go b/pkg/parse.go index c3836feb..e647c9ce 100644 --- a/pkg/parse.go +++ b/pkg/parse.go @@ -190,7 +190,7 @@ func (p *Parser) Find(name string) (*Interface, error) { for _, entry := range p.entries { for _, iface := range entry.interfaces { if iface == name { - list := p.packageInterfaces(entry.pkg.Types, entry.fileName, []string{name}, nil) + list := p.packageInterfaces(entry, []string{name}, nil) if len(list) > 0 { return list[0], nil } @@ -204,7 +204,7 @@ func (p *Parser) Interfaces() []*Interface { ifaces := make(sortableIFaceList, 0) for _, entry := range p.entries { declaredIfaces := entry.interfaces - ifaces = p.packageInterfaces(entry.pkg.Types, entry.fileName, declaredIfaces, ifaces) + ifaces = p.packageInterfaces(entry, declaredIfaces, ifaces) } sort.Sort(ifaces) @@ -212,11 +212,10 @@ func (p *Parser) Interfaces() []*Interface { } func (p *Parser) packageInterfaces( - pkg *types.Package, - fileName string, + entry *parserEntry, declaredInterfaces []string, ifaces []*Interface) []*Interface { - scope := pkg.Scope() + scope := entry.pkg.Types.Scope() for _, name := range declaredInterfaces { obj := scope.Lookup(name) if obj == nil { @@ -235,11 +234,12 @@ func (p *Parser) packageInterfaces( } elem := &Interface{ - Name: name, - Pkg: pkg, - QualifiedName: pkg.Path(), - FileName: fileName, - NamedType: typ, + Name: name, + PackagesPackage: entry.pkg, + Pkg: entry.pkg.Types, + QualifiedName: entry.pkg.Types.Path(), + FileName: entry.fileName, + NamedType: typ, } iface, ok := typ.Underlying().(*types.Interface) @@ -266,8 +266,6 @@ type TypesPackage interface { Path() string } - - type sortableIFaceList []*Interface func (s sortableIFaceList) Len() int { diff --git a/pkg/registry/method_scope.go b/pkg/registry/method_scope.go index 65b56162..8d301342 100644 --- a/pkg/registry/method_scope.go +++ b/pkg/registry/method_scope.go @@ -80,7 +80,7 @@ func (m MethodScope) populateImports(t types.Type, imports map[string]*Package) switch t := t.(type) { case *types.Named: if pkg := t.Obj().Pkg(); pkg != nil { - imports[stripVendorPath(pkg.Path())] = m.registry.AddImport(pkg) + imports[pkg.Path()] = m.registry.AddImport(pkg) } // The imports of a Type with a TypeList must be added to the imports list // For example: Foo[otherpackage.Bar] , must have otherpackage imported diff --git a/pkg/registry/package.go b/pkg/registry/package.go index d4945ea6..90192286 100644 --- a/pkg/registry/package.go +++ b/pkg/registry/package.go @@ -1,7 +1,6 @@ package registry import ( - "path" "strings" ) @@ -42,7 +41,7 @@ func (p *Package) Path() string { return "" } - return stripVendorPath(p.pkg.Path()) + return p.pkg.Path() } var replacer = strings.NewReplacer( @@ -72,8 +71,6 @@ func (p Package) uniqueName(lvl int) string { return name } - - func min(a, b int) int { if a < b { return a diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 75a70c04..3cc978a7 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -8,7 +8,6 @@ import ( "sort" "strings" - "github.com/chigopher/pathlib" "golang.org/x/tools/go/packages" ) @@ -18,19 +17,19 @@ import ( // qualifiers. type Registry struct { srcPkgName string + srcPkgPath string srcPkgTypes *types.Package - outputPath *pathlib.Path aliases map[string]string imports map[string]*Package } // New loads the source package info and returns a new instance of // Registry. -func New(srcPkg *packages.Package, outputPath *pathlib.Path) (*Registry, error) { +func New(srcPkg *packages.Package) (*Registry, error) { return &Registry{ srcPkgName: srcPkg.Name, + srcPkgPath: srcPkg.PkgPath, srcPkgTypes: srcPkg.Types, - outputPath: outputPath, aliases: parseImportsAliases(srcPkg.Syntax), imports: make(map[string]*Package), }, nil @@ -79,8 +78,8 @@ func (r *Registry) MethodScope() *MethodScope { // suitable alias if there are any conflicts with previously imported // packages. func (r *Registry) AddImport(pkg *types.Package) *Package { - path := stripVendorPath(pkg.Path()) - if pathlib.NewPath(path).Equals(r.outputPath) { + path := pkg.Path() + if pkg.Path() == r.srcPkgPath { return nil } diff --git a/pkg/registry/registry_test.go b/pkg/registry/registry_test.go index 48f91328..b2a276fb 100644 --- a/pkg/registry/registry_test.go +++ b/pkg/registry/registry_test.go @@ -1,11 +1 @@ package registry - -import ( - "testing" -) - -func BenchmarkNew(b *testing.B) { - for i := 0; i < b.N; i++ { - New("../../pkg/moq/testpackages/example", "") - } -} diff --git a/pkg/registry/var.go b/pkg/registry/var.go index 081a17c7..de8da2dc 100644 --- a/pkg/registry/var.go +++ b/pkg/registry/var.go @@ -30,7 +30,7 @@ func (v Var) TypeString() string { // packageQualifier is a types.Qualifier. func (v Var) packageQualifier(pkg *types.Package) string { - path := stripVendorPath(pkg.Path()) + path := pkg.Path() if v.moqPkgPath != "" && v.moqPkgPath == path { return "" } diff --git a/pkg/template_generator.go b/pkg/template_generator.go index 27da8568..b59c0f8b 100644 --- a/pkg/template_generator.go +++ b/pkg/template_generator.go @@ -1,39 +1,110 @@ package pkg import ( + "bytes" + "context" + "fmt" + "go/types" + + "github.com/chigopher/pathlib" + "github.com/rs/zerolog" "github.com/vektra/mockery/v2/pkg/config" + "github.com/vektra/mockery/v2/pkg/registry" "github.com/vektra/mockery/v2/pkg/template" + "golang.org/x/tools/go/packages" ) type TemplateGeneratorConfig struct { Style string } type TemplateGenerator struct { - config TemplateGeneratorConfig + config TemplateGeneratorConfig + registry *registry.Registry } -func NewTemplateGenerator(config TemplateGeneratorConfig) *TemplateGenerator { - return &TemplateGenerator{ - config: config, +func NewTemplateGenerator(srcPkg *packages.Package, config TemplateGeneratorConfig) (*TemplateGenerator, error) { + reg, err := registry.New(srcPkg) + if err != nil { + return nil, fmt.Errorf("creating new registry: %w", err) } + + return &TemplateGenerator{ + config: config, + registry: reg, + }, nil } -func (g *TemplateGenerator) Generate(iface *Interface, ifaceConfig *config.Config) error { - templ, err := template.New(g.config.Style) - if err != nil { - return err - } +func (g *TemplateGenerator) Generate(ctx context.Context, iface *Interface, ifaceConfig *config.Config) error { + log := zerolog.Ctx(ctx) + log.Info().Msg("generating mock for interface") + imports := Imports{} for _, method := range iface.Methods() { method.populateImports(imports) } - // TODO: Work on getting these imports into the template + methods := make([]template.MethodData, iface.ActualInterface.NumMethods()) + + for i := 0; i < iface.ActualInterface.NumMethods(); i++ { + method := iface.ActualInterface.Method(i) + methodScope := g.registry.MethodScope() + + signature := method.Type().(*types.Signature) + params := make([]template.ParamData, signature.Params().Len()) + for j := 0; j < signature.Params().Len(); j++ { + param := signature.Params().At(j) + params[j] = template.ParamData{ + Var: methodScope.AddVar(param, ""), + Variadic: signature.Variadic() && j == signature.Params().Len()-1, + } + } + + returns := make([]template.ParamData, signature.Results().Len()) + for j := 0; j < signature.Results().Len(); j++ { + param := signature.Results().At(j) + returns[j] = template.ParamData{ + Var: methodScope.AddVar(param, "Out"), + Variadic: false, + } + } + methods[i] = template.MethodData{ + Name: method.Name(), + Params: params, + Returns: returns, + } + + } + + // For now, mockery only supports one mock per file, which is why we're creating + // a single-element list. moq seems to have supported multiple mocks per file. + mockData := []template.MockData{ + { + InterfaceName: iface.Name, + MockName: ifaceConfig.MockName, + Methods: methods, + }, + } data := template.Data{ PkgName: ifaceConfig.Outpkg, SrcPkgQualifier: iface.Pkg.Name() + ".", - Imports: + Imports: g.registry.Imports(), + Mocks: mockData, } + templ, err := template.New(g.config.Style) + if err != nil { + return fmt.Errorf("creating new template: %w", err) + } + + var buf bytes.Buffer + if err := templ.Execute(&buf, data); err != nil { + return fmt.Errorf("executing template: %w", err) + } + + outPath := pathlib.NewPath(ifaceConfig.Dir).Join(ifaceConfig.FileName) + if err := outPath.WriteFile(buf.Bytes()); err != nil { + log.Error().Err(err).Msg("couldn't write to output file") + return fmt.Errorf("writing to output file: %w", err) + } return nil }