diff --git a/.gitignore b/.gitignore index f3f60b9..d3f7522 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ mock_display_test.go +mock_generic_display_test.go debug.test .vscode *.coverprofile diff --git a/dsl_test.go b/dsl_test.go index 42a4bbf..cb666e1 100644 --- a/dsl_test.go +++ b/dsl_test.go @@ -963,6 +963,24 @@ var _ = Describe("MockDisplay", func() { }) }) +var _ = Describe("GenericMockDisplay", func() { + var display *MockGenericDisplay[string, int64] + + BeforeEach(func() { + display = NewMockGenericDisplay[string, int64]() + }) + + Context("Stubbing generic method with generic value", func() { + BeforeEach(func() { + When(display.GenericParams(map[string]int64{"Hello": 333})).ThenReturn(int64(666)) + }) + + It("returns stubbed generic value", func() { + Expect(display.GenericParams(map[string]int64{"Hello": 333})).To(Equal(int64(666))) + }) + }) +}) + func flattenStringSliceOfSlices(sliceOfSlices [][]string) (result []string) { for _, slice := range sliceOfSlices { result = append(result, slice...) diff --git a/generate_test_mocks/xtools_go_loader/generate_test.go b/generate_test_mocks/xtools_go_loader/generate_test.go index d40c55e..2401614 100644 --- a/generate_test_mocks/xtools_go_loader/generate_test.go +++ b/generate_test_mocks/xtools_go_loader/generate_test.go @@ -23,7 +23,7 @@ import ( ) func TestMockGeneration(t *testing.T) { - RunSpecs(t, "Generating mocks with golang.org/x/tools/go/loader") + RunSpecs(t, "Generating mocks with golang.org/x/tools/go/packages") } var _ = It("Generate mocks", func() { @@ -31,4 +31,9 @@ var _ = It("Generate mocks", func() { []string{"github.com/petergtz/pegomock/v3/test_interface", "Display"}, "../../mock_display_test.go", "MockDisplay", "pegomock_test", "", false, os.Stdout, true, true, "") + + filehandling.GenerateMockFile( + []string{"github.com/petergtz/pegomock/v3/test_interface", "GenericDisplay"}, + "../../mock_generic_display_test.go", "MockGenericDisplay", "pegomock_test", + "", false, os.Stdout, true, false, "") }) diff --git a/go.mod b/go.mod index 0d827dd..3dbf6c9 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/xhit/go-str2duration/v2 v2.1.0 // indirect + golang.org/x/mod v0.10.0 // indirect golang.org/x/net v0.9.0 // indirect golang.org/x/sys v0.7.0 // indirect golang.org/x/text v0.9.0 // indirect diff --git a/go.sum b/go.sum index 6a9e3ab..dbf309b 100644 --- a/go.sum +++ b/go.sum @@ -31,8 +31,10 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc= github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= +golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index b96da1a..5b699c4 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -149,9 +149,12 @@ func sanitize(s string) string { } func (g *generator) generateMockFor(iface *model.Interface, mockTypeName, selfPackage string) { - g.generateMockType(mockTypeName) + typeParamNames := typeParamsStringFrom(iface.TypeParams, g.packageMap, selfPackage, false) + typeParams := typeParamsStringFrom(iface.TypeParams, g.packageMap, selfPackage, true) + g.generateMockType(mockTypeName, typeParams, + typeParamNames) for _, method := range iface.Methods { - g.generateMockMethod(mockTypeName, method, selfPackage) + g.generateMockMethod(mockTypeName, typeParamNames, method, selfPackage) g.emptyLine() addTypesFromMethodParamsTo(g.typesSet, method.In, g.packageMap) @@ -160,42 +163,59 @@ func (g *generator) generateMockFor(iface *model.Interface, mockTypeName, selfPa addTypesFromMethodParamsTo(g.typesSet, []*model.Parameter{method.Variadic}, g.packageMap) } } - g.generateMockVerifyMethods(mockTypeName) - g.generateVerifierType(mockTypeName) + g.generateMockVerifyMethods(mockTypeName, typeParamNames) + g.generateVerifierType(mockTypeName, typeParams, typeParamNames) for _, method := range iface.Methods { ongoingVerificationTypeName := fmt.Sprintf("%v_%v_OngoingVerification", mockTypeName, method.Name) args, argNames, argTypes, _ := argDataFor(method, g.packageMap, selfPackage) - g.generateVerifierMethod(mockTypeName, method, selfPackage, ongoingVerificationTypeName, args, argNames) - g.generateOngoingVerificationType(mockTypeName, ongoingVerificationTypeName) - g.generateOngoingVerificationGetCapturedArguments(ongoingVerificationTypeName, argNames, argTypes) - g.generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationTypeName, argTypes, method.Variadic != nil) + g.generateVerifierMethod(mockTypeName, typeParamNames, method, selfPackage, ongoingVerificationTypeName, args, argNames) + g.generateOngoingVerificationType(mockTypeName, typeParams, typeParamNames, ongoingVerificationTypeName) + g.generateOngoingVerificationGetCapturedArguments(ongoingVerificationTypeName, argNames, argTypes, typeParamNames) + g.generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationTypeName, typeParamNames, argTypes, method.Variadic != nil) } } -func (g *generator) generateMockType(mockTypeName string) { +func typeParamsStringFrom(params []*model.Parameter, packageMap map[string]string, pkgOverride string, withTypes bool) string { + if len(params) == 0 { + return "" + } + result := "[" + for i, param := range params { + if i > 0 { + result += ", " + } + result += param.Name + if withTypes { + result += " " + param.Type.String(packageMap, pkgOverride) + } + } + return result + "]" +} + +func (g *generator) generateMockType(mockTypeName string, typeParams string, typeParamNames string) { g. emptyLine(). - p("type %v struct {", mockTypeName). + p("type %v%v struct {", mockTypeName, typeParams). p(" fail func(message string, callerSkip ...int)"). p("}"). emptyLine(). - p("func New%v(options ...pegomock.Option) *%v {", mockTypeName, mockTypeName). - p(" mock := &%v{}", mockTypeName). + p("func New%v%v(options ...pegomock.Option) *%v%v {", mockTypeName, typeParams, mockTypeName, typeParamNames). + p(" mock := &%v%v{}", mockTypeName, typeParamNames). p(" for _, option := range options {"). p(" option.Apply(mock)"). p(" }"). p(" return mock"). p("}"). emptyLine(). - p("func (mock *%v) SetFailHandler(fh pegomock.FailHandler) { mock.fail = fh }", mockTypeName). - p("func (mock *%v) FailHandler() pegomock.FailHandler { return mock.fail }", mockTypeName). + p("func (mock *%v%v) SetFailHandler(fh pegomock.FailHandler) { mock.fail = fh }", mockTypeName, typeParamNames). + p("func (mock *%v%v) FailHandler() pegomock.FailHandler { return mock.fail }", mockTypeName, typeParamNames). emptyLine() } // If non-empty, pkgOverride is the package in which unqualified types reside. -func (g *generator) generateMockMethod(mockType string, method *model.Method, pkgOverride string) *generator { +func (g *generator) generateMockMethod(mockType string, typeParamNames string, method *model.Method, pkgOverride string) *generator { args, argNames, _, returnTypes := argDataFor(method, g.packageMap, pkgOverride) - g.p("func (mock *%v) %v(%v) (%v) {", mockType, method.Name, join(args), join(stringSliceFrom(returnTypes, g.packageMap, pkgOverride))) + g.p("func (mock *%v%v) %v(%v) (%v) {", mockType, typeParamNames, method.Name, join(args), join(stringSliceFrom(returnTypes, g.packageMap, pkgOverride))) g.p("if mock == nil {"). p(" panic(\"mock must not be nil. Use myMock := New%v().\")", mockType). p("}") @@ -240,10 +260,10 @@ func (g *generator) generateMockMethod(mockType string, method *model.Method, pk return g } -func (g *generator) generateVerifierType(interfaceName string) *generator { +func (g *generator) generateVerifierType(interfaceName string, typeParams string, typeParamNames string) *generator { return g. - p("type Verifier%v struct {", interfaceName). - p(" mock *%v", interfaceName). + p("type Verifier%v%v struct {", interfaceName, typeParams). + p(" mock *%v%v", interfaceName, typeParamNames). p(" invocationCountMatcher pegomock.InvocationCountMatcher"). p(" inOrderContext *pegomock.InOrderContext"). p(" timeout time.Duration"). @@ -251,32 +271,32 @@ func (g *generator) generateVerifierType(interfaceName string) *generator { emptyLine() } -func (g *generator) generateMockVerifyMethods(interfaceName string) { +func (g *generator) generateMockVerifyMethods(interfaceName string, typeParamNames string) { g. - p("func (mock *%v) VerifyWasCalledOnce() *Verifier%v {", interfaceName, interfaceName). - p(" return &Verifier%v{", interfaceName). + p("func (mock *%v%v) VerifyWasCalledOnce() *Verifier%v%v {", interfaceName, typeParamNames, interfaceName, typeParamNames). + p(" return &Verifier%v%v{", interfaceName, typeParamNames). p(" mock: mock,"). p(" invocationCountMatcher: pegomock.Times(1),"). p(" }"). p("}"). emptyLine(). - p("func (mock *%v) VerifyWasCalled(invocationCountMatcher pegomock.InvocationCountMatcher) *Verifier%v {", interfaceName, interfaceName). - p(" return &Verifier%v{", interfaceName). + p("func (mock *%v%v) VerifyWasCalled(invocationCountMatcher pegomock.InvocationCountMatcher) *Verifier%v%v {", interfaceName, typeParamNames, interfaceName, typeParamNames). + p(" return &Verifier%v%v{", interfaceName, typeParamNames). p(" mock: mock,"). p(" invocationCountMatcher: invocationCountMatcher,"). p(" }"). p("}"). emptyLine(). - p("func (mock *%v) VerifyWasCalledInOrder(invocationCountMatcher pegomock.InvocationCountMatcher, inOrderContext *pegomock.InOrderContext) *Verifier%v {", interfaceName, interfaceName). - p(" return &Verifier%v{", interfaceName). + p("func (mock *%v%v) VerifyWasCalledInOrder(invocationCountMatcher pegomock.InvocationCountMatcher, inOrderContext *pegomock.InOrderContext) *Verifier%v%v {", interfaceName, typeParamNames, interfaceName, typeParamNames). + p(" return &Verifier%v%v{", interfaceName, typeParamNames). p(" mock: mock,"). p(" invocationCountMatcher: invocationCountMatcher,"). p(" inOrderContext: inOrderContext,"). p(" }"). p("}"). emptyLine(). - p("func (mock *%v) VerifyWasCalledEventually(invocationCountMatcher pegomock.InvocationCountMatcher, timeout time.Duration) *Verifier%v {", interfaceName, interfaceName). - p(" return &Verifier%v{", interfaceName). + p("func (mock *%v%v) VerifyWasCalledEventually(invocationCountMatcher pegomock.InvocationCountMatcher, timeout time.Duration) *Verifier%v%v {", interfaceName, typeParamNames, interfaceName, typeParamNames). + p(" return &Verifier%v%v{", interfaceName, typeParamNames). p(" mock: mock,"). p(" invocationCountMatcher: invocationCountMatcher,"). p(" timeout: timeout,"). @@ -285,12 +305,12 @@ func (g *generator) generateMockVerifyMethods(interfaceName string) { emptyLine() } -func (g *generator) generateVerifierMethod(interfaceName string, method *model.Method, pkgOverride string, returnTypeString string, args []string, argNames []string) *generator { +func (g *generator) generateVerifierMethod(interfaceName string, typeParamNames string, method *model.Method, pkgOverride string, returnTypeString string, args []string, argNames []string) *generator { return g. - p("func (verifier *Verifier%v) %v(%v) *%v {", interfaceName, method.Name, join(args), returnTypeString). + p("func (verifier *Verifier%v%v) %v(%v) *%v%v {", interfaceName, typeParamNames, method.Name, join(args), returnTypeString, typeParamNames). GenerateParamsDeclaration(argNames, method.Variadic != nil). p("methodInvocations := pegomock.GetGenericMockFrom(verifier.mock).Verify(verifier.inOrderContext, verifier.invocationCountMatcher, \"%v\", params, verifier.timeout)", method.Name). - p("return &%v{mock: verifier.mock, methodInvocations: methodInvocations}", returnTypeString). + p("return &%v%v{mock: verifier.mock, methodInvocations: methodInvocations}", returnTypeString, typeParamNames). p("}") } @@ -306,17 +326,17 @@ func (g *generator) GenerateParamsDeclaration(argNames []string, isVariadic bool } } -func (g *generator) generateOngoingVerificationType(interfaceName string, ongoingVerificationStructName string) *generator { +func (g *generator) generateOngoingVerificationType(interfaceName string, typeParams string, typeParamNames string, ongoingVerificationStructName string) *generator { return g. - p("type %v struct {", ongoingVerificationStructName). - p("mock *%v", interfaceName). + p("type %v%v struct {", ongoingVerificationStructName, typeParams). + p("mock *%v%v", interfaceName, typeParamNames). p(" methodInvocations []pegomock.MethodInvocation"). p("}"). emptyLine() } -func (g *generator) generateOngoingVerificationGetCapturedArguments(ongoingVerificationStructName string, argNames []string, argTypes []string) *generator { - g.p("func (c *%v) GetCapturedArguments() (%v) {", ongoingVerificationStructName, join(argTypes)) +func (g *generator) generateOngoingVerificationGetCapturedArguments(ongoingVerificationStructName string, argNames []string, argTypes []string, typeParamNames string) *generator { + g.p("func (c *%v%v) GetCapturedArguments() (%v) {", ongoingVerificationStructName, typeParamNames, join(argTypes)) if len(argNames) > 0 { indexedArgNames := make([]string, len(argNames)) for i, argName := range argNames { @@ -330,12 +350,12 @@ func (g *generator) generateOngoingVerificationGetCapturedArguments(ongoingVerif return g } -func (g *generator) generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationStructName string, argTypes []string, isVariadic bool) *generator { +func (g *generator) generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationStructName string, typeParamNames string, argTypes []string, isVariadic bool) *generator { argsAsArray := make([]string, len(argTypes)) for i, argType := range argTypes { argsAsArray[i] = fmt.Sprintf("_param%v []%v", i, argType) } - g.p("func (c *%v) GetAllCapturedArguments() (%v) {", ongoingVerificationStructName, strings.Join(argsAsArray, ", ")) + g.p("func (c *%v%v) GetAllCapturedArguments() (%v) {", ongoingVerificationStructName, typeParamNames, strings.Join(argsAsArray, ", ")) if len(argTypes) > 0 { g.p("params := pegomock.GetGenericMockFrom(c.mock).GetInvocationParams(c.methodInvocations)") g.p("if len(params) > 0 {") diff --git a/mockgen/mockgen_test.go b/mockgen/mockgen_test.go index 8766c20..e847c90 100644 --- a/mockgen/mockgen_test.go +++ b/mockgen/mockgen_test.go @@ -2,7 +2,7 @@ package mockgen_test import ( "github.com/petergtz/pegomock/v3/mockgen" - "github.com/petergtz/pegomock/v3/modelgen/loader" + "github.com/petergtz/pegomock/v3/modelgen/xtools_packages" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -11,7 +11,7 @@ import ( var _ = Describe("Mockgen", func() { Context("matcherSourceCodes", func() { It("uses correct naming pattern with underscores for keys, and correct types etc. in source code", func() { - ast, e := loader.GenerateModel("github.com/petergtz/pegomock/v3/test_interface", "Display") + ast, e := xtools_packages.GenerateModel("github.com/petergtz/pegomock/v3/test_interface", "Display") Expect(e).NotTo(HaveOccurred()) _, matcherSourceCodes := mockgen.GenerateOutput(ast, "irrelevant", "MockDisplay", "test_package", "") diff --git a/model/model.go b/model/model.go index e09de71..4f7dcd3 100644 --- a/model/model.go +++ b/model/model.go @@ -46,8 +46,9 @@ func (pkg *Package) Imports() map[string]bool { // Interface is a Go interface. type Interface struct { - Name string - Methods []*Method + Name string + TypeParams []*Parameter + Methods []*Method } func (intf *Interface) Print(w io.Writer) { @@ -58,6 +59,9 @@ func (intf *Interface) Print(w io.Writer) { } func (intf *Interface) addImports(im map[string]bool) { + for _, tp := range intf.TypeParams { + tp.Type.addImports(im) + } for _, m := range intf.Methods { m.addImports(im) } diff --git a/modelgen/gomock/parse.go b/modelgen/gomock/parse.go index 95d97fd..b9eb09a 100644 --- a/modelgen/gomock/parse.go +++ b/modelgen/gomock/parse.go @@ -23,6 +23,7 @@ import ( "go/parser" "go/token" "log" + "os" "path" "strconv" "strings" @@ -38,6 +39,8 @@ var ( // TODO: simplify error reporting func ParseFile(source string) (*model.Package, error) { + fmt.Fprintln(os.Stderr, "WARNING: The gomock package is deprecated and will be removed in a future version.") + fs := token.NewFileSet() file, err := parser.ParseFile(fs, source, nil, 0) if err != nil { diff --git a/modelgen/modelgen_test.go b/modelgen/modelgen_test.go index 7655b35..f3495d5 100644 --- a/modelgen/modelgen_test.go +++ b/modelgen/modelgen_test.go @@ -21,7 +21,7 @@ import ( "github.com/petergtz/pegomock/v3/model" "github.com/petergtz/pegomock/v3/modelgen/gomock" - "github.com/petergtz/pegomock/v3/modelgen/loader" + "github.com/petergtz/pegomock/v3/modelgen/xtools_packages" "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2" @@ -40,13 +40,13 @@ func (a alphabetically) Len() int { return len(a) } func (a alphabetically) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a alphabetically) Less(i, j int) bool { return a[i].Name < a[j].Name } -var _ = Describe("modelgen/loader", func() { +var _ = Describe("xtools_packages", func() { It("generates an equivalent model as gomock/reflect does", func() { pkgFromReflect, e := gomock.Reflect("github.com/petergtz/pegomock/v3/test_interface", []string{"Display"}) Expect(e).NotTo(HaveOccurred()) sort.Sort(alphabetically(pkgFromReflect.Interfaces[0].Methods)) - pkgFromLoader, e := loader.GenerateModel("github.com/petergtz/pegomock/v3/test_interface", "Display") + pkgFromLoader, e := xtools_packages.GenerateModel("github.com/petergtz/pegomock/v3/test_interface", "Display") Expect(e).NotTo(HaveOccurred()) sort.Sort(alphabetically(pkgFromLoader.Interfaces[0].Methods)) @@ -60,7 +60,7 @@ var _ = Describe("modelgen/loader", func() { }) It("generates a model with the basic properties", func() { - pkg, e := loader.GenerateModel("github.com/petergtz/pegomock/v3/modelgen/test_data/default_test_interface", "Display") + pkg, e := xtools_packages.GenerateModel("github.com/petergtz/pegomock/v3/modelgen/test_data/default_test_interface", "Display") Expect(e).NotTo(HaveOccurred()) Expect(pkg.Name).To(Equal("test_interface")) diff --git a/modelgen/loader/loader.go b/modelgen/xtools_loader/loader.go similarity index 98% rename from modelgen/loader/loader.go rename to modelgen/xtools_loader/loader.go index a51eb46..e645141 100644 --- a/modelgen/loader/loader.go +++ b/modelgen/xtools_loader/loader.go @@ -1,4 +1,4 @@ -package loader +package xtools_loader import ( "errors" @@ -11,6 +11,8 @@ import ( ) func GenerateModel(importPath string, interfaceName string) (*model.Package, error) { + panic("DEPRECATED: Use GenerateModelViaPackages instead") + var conf loader.Config conf.Import(importPath) program, e := conf.Load() diff --git a/modelgen/loader/loader_suite_test.go b/modelgen/xtools_loader/loader_suite_test.go similarity index 68% rename from modelgen/loader/loader_suite_test.go rename to modelgen/xtools_loader/loader_suite_test.go index 35072dc..65520f7 100644 --- a/modelgen/loader/loader_suite_test.go +++ b/modelgen/xtools_loader/loader_suite_test.go @@ -1,4 +1,4 @@ -package loader_test +package xtools_loader_test import ( "testing" @@ -9,5 +9,5 @@ import ( func TestLoader(t *testing.T) { RegisterFailHandler(Fail) - RunSpecs(t, "Loader Suite") + RunSpecs(t, "xtools_loader_test Suite") } diff --git a/modelgen/loader/loader_test.go b/modelgen/xtools_loader/loader_test.go similarity index 88% rename from modelgen/loader/loader_test.go rename to modelgen/xtools_loader/loader_test.go index 2d6b51b..31f57ee 100644 --- a/modelgen/loader/loader_test.go +++ b/modelgen/xtools_loader/loader_test.go @@ -1,13 +1,13 @@ -package loader_test +package xtools_loader_test import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - . "github.com/petergtz/pegomock/v3/modelgen/loader" + . "github.com/petergtz/pegomock/v3/modelgen/xtools_loader" ) -var _ = Describe("Loader", func() { +var _ = XDescribe("Loader", func() { Describe("GenerateModel", func() { It("finds all methods within interface", func() { pkg, e := GenerateModel("io", "Reader") diff --git a/modelgen/xtools_packages/packages.go b/modelgen/xtools_packages/packages.go new file mode 100644 index 0000000..dc32554 --- /dev/null +++ b/modelgen/xtools_packages/packages.go @@ -0,0 +1,193 @@ +package xtools_packages + +import ( + "errors" + "fmt" + "go/types" + "path" + + "github.com/petergtz/pegomock/v3/model" + "golang.org/x/tools/go/packages" +) + +type Bla[K comparable, V Number] interface { + SumNumbers(m map[K]V, i int, s string, a []float32, sss ...string) V +} + +type Number interface { + int64 | float64 +} + +type Blub[K comparable, V Number] struct{} + +func (b *Blub[K, V]) SumNumbers(m map[K]V, i int, q string, a []float32) V { + var s V + for _, v := range m { + s += v + } + return s +} + +func NewBlub[K comparable, V Number]() *Blub[K, V] { + return &Blub[K, V]{} +} + +func GenerateModel(importPath string, interfaceName string) (*model.Package, error) { + + pkgs, e := packages.Load(&packages.Config{Mode: packages.NeedTypes}, importPath) + if e != nil { + return nil, e + } + for _, pkg := range pkgs { + scope := pkg.Types.Scope() + obj := scope.Lookup(interfaceName) + if obj == nil { + continue + } + // from here, things follow the spec in https://tip.golang.org/ref/spec + if typeName, isTypeName := obj.(*types.TypeName); isTypeName { + if iface, isIface := typeName.Type().Underlying().(*types.Interface); isIface { + + g := &modelGenerator2{typeParams: make(map[string]*model.Parameter)} + methods := g.modelMethodsFrom(iface) + + return &model.Package{ + Name: path.Base(pkg.Types.Name()), + Interfaces: []*model.Interface{{ + Name: interfaceName, + Methods: methods, + TypeParams: sliceFrom(g.typeParams), + }}, + }, nil + } + } + } + + return nil, errors.New("Did not find interface name \"" + interfaceName + "\"") +} + +type modelGenerator2 struct { + typeParams map[string]*model.Parameter +} + +func (g *modelGenerator2) modelMethodsFrom(iface *types.Interface) (modelMethods []*model.Method) { + for i := 0; i < iface.NumMethods(); i++ { + modelMethods = append(modelMethods, g.modelMethodFrom(iface.Method(i))) + } + return +} + +func (g *modelGenerator2) modelMethodFrom(method *types.Func) *model.Method { + signature := method.Type().(*types.Signature) + in, variadic := g.inParamsFrom(signature) + return &model.Method{ + Name: method.Name(), + In: in, + Variadic: variadic, + Out: g.outParamsFrom(signature), + } +} + +func (g *modelGenerator2) inParamsFrom(signature *types.Signature) (in []*model.Parameter, variadic *model.Parameter) { + for u := 0; u < signature.Params().Len(); u++ { + if signature.Variadic() && u == signature.Params().Len()-1 { + variadic = &model.Parameter{ + Name: signature.Params().At(u).Name(), + Type: g.modelTypeFrom(signature.Params().At(u).Type().(*types.Slice).Elem()), + } + break + } + in = append(in, &model.Parameter{ + Name: signature.Params().At(u).Name(), + Type: g.modelTypeFrom(signature.Params().At(u).Type()), + }) + } + return +} + +func (g *modelGenerator2) outParamsFrom(signature *types.Signature) (out []*model.Parameter) { + if signature.Results() != nil { + for u := 0; u < signature.Results().Len(); u++ { + out = append(out, &model.Parameter{ + Name: signature.Results().At(u).Name(), + Type: g.modelTypeFrom(signature.Results().At(u).Type()), + }) + } + } + return +} + +func (g *modelGenerator2) modelTypeFrom(typesType types.Type) model.Type { + switch typedTyp := typesType.(type) { + case *types.Basic: + if !predeclared(typedTyp.Kind()) { + panic(fmt.Sprintf("Unexpected Basic Type %v", typedTyp.Name())) + } + return model.PredeclaredType(typedTyp.Name()) + case *types.Pointer: + return &model.PointerType{ + Type: g.modelTypeFrom(typedTyp.Elem()), + } + case *types.Array: + return &model.ArrayType{ + Len: int(typedTyp.Len()), + Type: g.modelTypeFrom(typedTyp.Elem()), + } + case *types.Slice: + return &model.ArrayType{ + Len: -1, + Type: g.modelTypeFrom(typedTyp.Elem()), + } + case *types.Map: + return &model.MapType{ + Key: g.modelTypeFrom(typedTyp.Key()), + Value: g.modelTypeFrom(typedTyp.Elem()), + } + case *types.Chan: + var dir model.ChanDir + switch typedTyp.Dir() { + case types.SendOnly: + dir = model.SendDir + case types.RecvOnly: + dir = model.RecvDir + default: + dir = 0 + } + return &model.ChanType{ + Dir: dir, + Type: g.modelTypeFrom(typedTyp.Elem()), + } + case *types.Named: + if typedTyp.Obj().Pkg() == nil { + return model.PredeclaredType(typedTyp.Obj().Name()) + } + return &model.NamedType{ + Package: typedTyp.Obj().Pkg().Path(), + Type: typedTyp.Obj().Name(), + } + case *types.Interface, *types.Struct: + return model.PredeclaredType(typedTyp.String()) + case *types.Signature: + in, variadic := g.inParamsFrom(typedTyp) + return &model.FuncType{In: in, Out: g.outParamsFrom(typedTyp), Variadic: variadic} + case *types.TypeParam: + g.typeParams[typedTyp.Obj().Name()] = &model.Parameter{ + Name: typedTyp.Obj().Name(), + Type: g.modelTypeFrom(typedTyp.Constraint()), + } + return model.PredeclaredType(typedTyp.Obj().Name()) + default: + panic(fmt.Sprintf("Unknown types.Type: %v (%T)", typesType, typesType)) + } +} + +func sliceFrom(params map[string]*model.Parameter) (result []*model.Parameter) { + for _, v := range params { + result = append(result, v) + } + return +} + +func predeclared(basicKind types.BasicKind) bool { + return basicKind >= types.Bool && basicKind <= types.String +} diff --git a/modelgen/xtools_packages/packages_suite_test.go b/modelgen/xtools_packages/packages_suite_test.go new file mode 100644 index 0000000..5f2e03e --- /dev/null +++ b/modelgen/xtools_packages/packages_suite_test.go @@ -0,0 +1,13 @@ +package xtools_packages_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestLoader(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "xtools_packages_test Suite") +} diff --git a/modelgen/xtools_packages/packages_test.go b/modelgen/xtools_packages/packages_test.go new file mode 100644 index 0000000..cfa5841 --- /dev/null +++ b/modelgen/xtools_packages/packages_test.go @@ -0,0 +1,65 @@ +package xtools_packages_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/petergtz/pegomock/v3/model" + + . "github.com/petergtz/pegomock/v3/modelgen/xtools_packages" +) + +var _ = Describe("Packages", func() { + Describe("GenerateModel", func() { + + It("finds all methods within interface", func() { + pkg, e := GenerateModel("io", "Reader") + Expect(e).NotTo(HaveOccurred()) + Expect(pkg.Interfaces).To(HaveLen(1)) + Expect(pkg.Interfaces[0].Name).To(Equal("Reader")) + Expect(pkg.Interfaces[0].Methods).To(HaveLen(1)) + Expect(pkg.Interfaces[0].Methods[0].Name).To(Equal("Read")) + }) + + Context("using an interface with embedded interfaces", func() { + It("finds all methods", func() { + pkg, e := GenerateModel("io", "ReadCloser") + Expect(e).NotTo(HaveOccurred()) + Expect(pkg.Interfaces).To(HaveLen(1)) + Expect(pkg.Interfaces[0].Name).To(Equal("ReadCloser")) + Expect(pkg.Interfaces[0].Methods).To(HaveLen(2)) + Expect([]string{pkg.Interfaces[0].Methods[0].Name, pkg.Interfaces[0].Methods[1].Name}).To( + ConsistOf("Read", "Close")) + }) + }) + + It("finds correct generic parameters in an interface", func() { + pkg, e := GenerateModel("io", "Reader") + Expect(e).NotTo(HaveOccurred()) + Expect(pkg.Interfaces).To(HaveLen(1)) + Expect(pkg.Interfaces[0].Name).To(Equal("Reader")) + Expect(pkg.Interfaces[0].Methods).To(HaveLen(1)) + Expect(pkg.Interfaces[0].Methods[0].Name).To(Equal("Read")) + + pkg, e = GenerateModel("github.com/petergtz/pegomock/v3/modelgen/xtools_packages", "Bla") + Expect(e).NotTo(HaveOccurred()) + Expect(pkg.Interfaces).To(HaveLen(1)) + Expect(pkg.Interfaces[0].Name).To(Equal("Bla")) + Expect(pkg.Interfaces[0].TypeParams).To(ConsistOf( + &model.Parameter{ + Name: "K", + Type: model.PredeclaredType("comparable"), + }, + &model.Parameter{ + Name: "V", + Type: &model.NamedType{ + Package: "github.com/petergtz/pegomock/v3/modelgen/xtools_packages", + Type: "Number", + }, + }, + )) + Expect(pkg.Interfaces[0].Methods).To(HaveLen(1)) + Expect(pkg.Interfaces[0].Methods[0].Name).To(Equal("SumNumbers")) + }) + + }) +}) diff --git a/pegomock/filehandling/filehandling.go b/pegomock/filehandling/filehandling.go index 78660cf..5dc92c3 100644 --- a/pegomock/filehandling/filehandling.go +++ b/pegomock/filehandling/filehandling.go @@ -12,7 +12,7 @@ import ( "github.com/petergtz/pegomock/v3/mockgen" "github.com/petergtz/pegomock/v3/model" "github.com/petergtz/pegomock/v3/modelgen/gomock" - "github.com/petergtz/pegomock/v3/modelgen/loader" + "github.com/petergtz/pegomock/v3/modelgen/xtools_packages" "github.com/petergtz/pegomock/v3/pegomock/util" ) @@ -99,8 +99,7 @@ func GenerateMockSourceCode(args []string, nameOut string, packageOut string, se log.Fatal("Expected exactly two arguments, but got " + fmt.Sprint(args)) } if useExperimentalModelGen { - ast, err = loader.GenerateModel(args[0], args[1]) - + ast, err = xtools_packages.GenerateModel(args[0], args[1]) } else { ast, err = gomock.Reflect(args[0], strings.Split(args[1], ",")) } diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index d5561ec..48dd6e8 100755 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -14,7 +14,8 @@ rm -rf matchers ginkgo -succinct generate_test_mocks/gomock_reflect ginkgo --skip-package=pegomock/watch --randomize-all --randomize-suites --race --trace -cover -rm -f mock_display_test.go -rm -rf matchers -ginkgo -succinct generate_test_mocks/gomock_source -ginkgo --skip-package=pegomock/watch --randomize-all --randomize-suites --race --trace -cover +# DEPRECATED: gomock_source is deprecated and will be removed in a future release. +#rm -f mock_display_test.go +#rm -rf matchers +#ginkgo -succinct generate_test_mocks/gomock_source +#ginkgo --skip-package=pegomock/watch --randomize-all --randomize-suites --race --trace -cover diff --git a/test_interface/display.go b/test_interface/display.go index 073f3c2..35132a3 100644 --- a/test_interface/display.go +++ b/test_interface/display.go @@ -48,3 +48,11 @@ type Display interface { MapWithRedundantImports(m map[http.File]http.File) MapOfStringToEmptyUnnamedStruct(m map[string]struct{}) } + +type GenericDisplay[N comparable, V Number] interface { + GenericParams(m map[N]V) V +} + +type Number interface { + int64 | float64 +}