Skip to content

Commit

Permalink
Introduce support for generics
Browse files Browse the repository at this point in the history
  • Loading branch information
petergtz committed May 9, 2023
1 parent f7f6595 commit 223960f
Show file tree
Hide file tree
Showing 19 changed files with 395 additions and 60 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mock_display_test.go
mock_generic_display_test.go
debug.test
.vscode
*.coverprofile
Expand Down
18 changes: 18 additions & 0 deletions dsl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,24 @@ var _ = Describe("MockDisplay", func() {
})
})

var _ = Describe("GenericMockDisplay", func() {
var display *MockGenericDisplay[string, int64]

BeforeEach(func() {
display = NewMockGenericDisplay[string, int64]()
})

Context("Stubbing generic method with generic value", func() {
BeforeEach(func() {
When(display.GenericParams(map[string]int64{"Hello": 333})).ThenReturn(int64(666))
})

It("returns stubbed generic value", func() {
Expect(display.GenericParams(map[string]int64{"Hello": 333})).To(Equal(int64(666)))
})
})
})

func flattenStringSliceOfSlices(sliceOfSlices [][]string) (result []string) {
for _, slice := range sliceOfSlices {
result = append(result, slice...)
Expand Down
7 changes: 6 additions & 1 deletion generate_test_mocks/xtools_go_loader/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@ import (
)

func TestMockGeneration(t *testing.T) {
RunSpecs(t, "Generating mocks with golang.org/x/tools/go/loader")
RunSpecs(t, "Generating mocks with golang.org/x/tools/go/packages")
}

var _ = It("Generate mocks", func() {
filehandling.GenerateMockFile(
[]string{"github.com/petergtz/pegomock/v3/test_interface", "Display"},
"../../mock_display_test.go", "MockDisplay", "pegomock_test",
"", false, os.Stdout, true, true, "")

filehandling.GenerateMockFile(
[]string{"github.com/petergtz/pegomock/v3/test_interface", "GenericDisplay"},
"../../mock_generic_display_test.go", "MockGenericDisplay", "pegomock_test",
"", false, os.Stdout, true, false, "")
})
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ require (
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/xhit/go-str2duration/v2 v2.1.0 // indirect
golang.org/x/mod v0.10.0 // indirect
golang.org/x/net v0.9.0 // indirect
golang.org/x/sys v0.7.0 // indirect
golang.org/x/text v0.9.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ
github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc=
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand Down
96 changes: 58 additions & 38 deletions mockgen/mockgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,12 @@ func sanitize(s string) string {
}

func (g *generator) generateMockFor(iface *model.Interface, mockTypeName, selfPackage string) {
g.generateMockType(mockTypeName)
typeParamNames := typeParamsStringFrom(iface.TypeParams, g.packageMap, selfPackage, false)
typeParams := typeParamsStringFrom(iface.TypeParams, g.packageMap, selfPackage, true)
g.generateMockType(mockTypeName, typeParams,
typeParamNames)
for _, method := range iface.Methods {
g.generateMockMethod(mockTypeName, method, selfPackage)
g.generateMockMethod(mockTypeName, typeParamNames, method, selfPackage)
g.emptyLine()

addTypesFromMethodParamsTo(g.typesSet, method.In, g.packageMap)
Expand All @@ -160,42 +163,59 @@ func (g *generator) generateMockFor(iface *model.Interface, mockTypeName, selfPa
addTypesFromMethodParamsTo(g.typesSet, []*model.Parameter{method.Variadic}, g.packageMap)
}
}
g.generateMockVerifyMethods(mockTypeName)
g.generateVerifierType(mockTypeName)
g.generateMockVerifyMethods(mockTypeName, typeParamNames)
g.generateVerifierType(mockTypeName, typeParams, typeParamNames)
for _, method := range iface.Methods {
ongoingVerificationTypeName := fmt.Sprintf("%v_%v_OngoingVerification", mockTypeName, method.Name)
args, argNames, argTypes, _ := argDataFor(method, g.packageMap, selfPackage)
g.generateVerifierMethod(mockTypeName, method, selfPackage, ongoingVerificationTypeName, args, argNames)
g.generateOngoingVerificationType(mockTypeName, ongoingVerificationTypeName)
g.generateOngoingVerificationGetCapturedArguments(ongoingVerificationTypeName, argNames, argTypes)
g.generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationTypeName, argTypes, method.Variadic != nil)
g.generateVerifierMethod(mockTypeName, typeParamNames, method, selfPackage, ongoingVerificationTypeName, args, argNames)
g.generateOngoingVerificationType(mockTypeName, typeParams, typeParamNames, ongoingVerificationTypeName)
g.generateOngoingVerificationGetCapturedArguments(ongoingVerificationTypeName, argNames, argTypes, typeParamNames)
g.generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationTypeName, typeParamNames, argTypes, method.Variadic != nil)
}
}

func (g *generator) generateMockType(mockTypeName string) {
func typeParamsStringFrom(params []*model.Parameter, packageMap map[string]string, pkgOverride string, withTypes bool) string {
if len(params) == 0 {
return ""
}
result := "["
for i, param := range params {
if i > 0 {
result += ", "
}
result += param.Name
if withTypes {
result += " " + param.Type.String(packageMap, pkgOverride)
}
}
return result + "]"
}

func (g *generator) generateMockType(mockTypeName string, typeParams string, typeParamNames string) {
g.
emptyLine().
p("type %v struct {", mockTypeName).
p("type %v%v struct {", mockTypeName, typeParams).
p(" fail func(message string, callerSkip ...int)").
p("}").
emptyLine().
p("func New%v(options ...pegomock.Option) *%v {", mockTypeName, mockTypeName).
p(" mock := &%v{}", mockTypeName).
p("func New%v%v(options ...pegomock.Option) *%v%v {", mockTypeName, typeParams, mockTypeName, typeParamNames).
p(" mock := &%v%v{}", mockTypeName, typeParamNames).
p(" for _, option := range options {").
p(" option.Apply(mock)").
p(" }").
p(" return mock").
p("}").
emptyLine().
p("func (mock *%v) SetFailHandler(fh pegomock.FailHandler) { mock.fail = fh }", mockTypeName).
p("func (mock *%v) FailHandler() pegomock.FailHandler { return mock.fail }", mockTypeName).
p("func (mock *%v%v) SetFailHandler(fh pegomock.FailHandler) { mock.fail = fh }", mockTypeName, typeParamNames).
p("func (mock *%v%v) FailHandler() pegomock.FailHandler { return mock.fail }", mockTypeName, typeParamNames).
emptyLine()
}

// If non-empty, pkgOverride is the package in which unqualified types reside.
func (g *generator) generateMockMethod(mockType string, method *model.Method, pkgOverride string) *generator {
func (g *generator) generateMockMethod(mockType string, typeParamNames string, method *model.Method, pkgOverride string) *generator {
args, argNames, _, returnTypes := argDataFor(method, g.packageMap, pkgOverride)
g.p("func (mock *%v) %v(%v) (%v) {", mockType, method.Name, join(args), join(stringSliceFrom(returnTypes, g.packageMap, pkgOverride)))
g.p("func (mock *%v%v) %v(%v) (%v) {", mockType, typeParamNames, method.Name, join(args), join(stringSliceFrom(returnTypes, g.packageMap, pkgOverride)))
g.p("if mock == nil {").
p(" panic(\"mock must not be nil. Use myMock := New%v().\")", mockType).
p("}")
Expand Down Expand Up @@ -240,43 +260,43 @@ func (g *generator) generateMockMethod(mockType string, method *model.Method, pk
return g
}

func (g *generator) generateVerifierType(interfaceName string) *generator {
func (g *generator) generateVerifierType(interfaceName string, typeParams string, typeParamNames string) *generator {
return g.
p("type Verifier%v struct {", interfaceName).
p(" mock *%v", interfaceName).
p("type Verifier%v%v struct {", interfaceName, typeParams).
p(" mock *%v%v", interfaceName, typeParamNames).
p(" invocationCountMatcher pegomock.InvocationCountMatcher").
p(" inOrderContext *pegomock.InOrderContext").
p(" timeout time.Duration").
p("}").
emptyLine()
}

func (g *generator) generateMockVerifyMethods(interfaceName string) {
func (g *generator) generateMockVerifyMethods(interfaceName string, typeParamNames string) {
g.
p("func (mock *%v) VerifyWasCalledOnce() *Verifier%v {", interfaceName, interfaceName).
p(" return &Verifier%v{", interfaceName).
p("func (mock *%v%v) VerifyWasCalledOnce() *Verifier%v%v {", interfaceName, typeParamNames, interfaceName, typeParamNames).
p(" return &Verifier%v%v{", interfaceName, typeParamNames).
p(" mock: mock,").
p(" invocationCountMatcher: pegomock.Times(1),").
p(" }").
p("}").
emptyLine().
p("func (mock *%v) VerifyWasCalled(invocationCountMatcher pegomock.InvocationCountMatcher) *Verifier%v {", interfaceName, interfaceName).
p(" return &Verifier%v{", interfaceName).
p("func (mock *%v%v) VerifyWasCalled(invocationCountMatcher pegomock.InvocationCountMatcher) *Verifier%v%v {", interfaceName, typeParamNames, interfaceName, typeParamNames).
p(" return &Verifier%v%v{", interfaceName, typeParamNames).
p(" mock: mock,").
p(" invocationCountMatcher: invocationCountMatcher,").
p(" }").
p("}").
emptyLine().
p("func (mock *%v) VerifyWasCalledInOrder(invocationCountMatcher pegomock.InvocationCountMatcher, inOrderContext *pegomock.InOrderContext) *Verifier%v {", interfaceName, interfaceName).
p(" return &Verifier%v{", interfaceName).
p("func (mock *%v%v) VerifyWasCalledInOrder(invocationCountMatcher pegomock.InvocationCountMatcher, inOrderContext *pegomock.InOrderContext) *Verifier%v%v {", interfaceName, typeParamNames, interfaceName, typeParamNames).
p(" return &Verifier%v%v{", interfaceName, typeParamNames).
p(" mock: mock,").
p(" invocationCountMatcher: invocationCountMatcher,").
p(" inOrderContext: inOrderContext,").
p(" }").
p("}").
emptyLine().
p("func (mock *%v) VerifyWasCalledEventually(invocationCountMatcher pegomock.InvocationCountMatcher, timeout time.Duration) *Verifier%v {", interfaceName, interfaceName).
p(" return &Verifier%v{", interfaceName).
p("func (mock *%v%v) VerifyWasCalledEventually(invocationCountMatcher pegomock.InvocationCountMatcher, timeout time.Duration) *Verifier%v%v {", interfaceName, typeParamNames, interfaceName, typeParamNames).
p(" return &Verifier%v%v{", interfaceName, typeParamNames).
p(" mock: mock,").
p(" invocationCountMatcher: invocationCountMatcher,").
p(" timeout: timeout,").
Expand All @@ -285,12 +305,12 @@ func (g *generator) generateMockVerifyMethods(interfaceName string) {
emptyLine()
}

func (g *generator) generateVerifierMethod(interfaceName string, method *model.Method, pkgOverride string, returnTypeString string, args []string, argNames []string) *generator {
func (g *generator) generateVerifierMethod(interfaceName string, typeParamNames string, method *model.Method, pkgOverride string, returnTypeString string, args []string, argNames []string) *generator {
return g.
p("func (verifier *Verifier%v) %v(%v) *%v {", interfaceName, method.Name, join(args), returnTypeString).
p("func (verifier *Verifier%v%v) %v(%v) *%v%v {", interfaceName, typeParamNames, method.Name, join(args), returnTypeString, typeParamNames).
GenerateParamsDeclaration(argNames, method.Variadic != nil).
p("methodInvocations := pegomock.GetGenericMockFrom(verifier.mock).Verify(verifier.inOrderContext, verifier.invocationCountMatcher, \"%v\", params, verifier.timeout)", method.Name).
p("return &%v{mock: verifier.mock, methodInvocations: methodInvocations}", returnTypeString).
p("return &%v%v{mock: verifier.mock, methodInvocations: methodInvocations}", returnTypeString, typeParamNames).
p("}")
}

Expand All @@ -306,17 +326,17 @@ func (g *generator) GenerateParamsDeclaration(argNames []string, isVariadic bool
}
}

func (g *generator) generateOngoingVerificationType(interfaceName string, ongoingVerificationStructName string) *generator {
func (g *generator) generateOngoingVerificationType(interfaceName string, typeParams string, typeParamNames string, ongoingVerificationStructName string) *generator {
return g.
p("type %v struct {", ongoingVerificationStructName).
p("mock *%v", interfaceName).
p("type %v%v struct {", ongoingVerificationStructName, typeParams).
p("mock *%v%v", interfaceName, typeParamNames).
p(" methodInvocations []pegomock.MethodInvocation").
p("}").
emptyLine()
}

func (g *generator) generateOngoingVerificationGetCapturedArguments(ongoingVerificationStructName string, argNames []string, argTypes []string) *generator {
g.p("func (c *%v) GetCapturedArguments() (%v) {", ongoingVerificationStructName, join(argTypes))
func (g *generator) generateOngoingVerificationGetCapturedArguments(ongoingVerificationStructName string, argNames []string, argTypes []string, typeParamNames string) *generator {
g.p("func (c *%v%v) GetCapturedArguments() (%v) {", ongoingVerificationStructName, typeParamNames, join(argTypes))
if len(argNames) > 0 {
indexedArgNames := make([]string, len(argNames))
for i, argName := range argNames {
Expand All @@ -330,12 +350,12 @@ func (g *generator) generateOngoingVerificationGetCapturedArguments(ongoingVerif
return g
}

func (g *generator) generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationStructName string, argTypes []string, isVariadic bool) *generator {
func (g *generator) generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationStructName string, typeParamNames string, argTypes []string, isVariadic bool) *generator {
argsAsArray := make([]string, len(argTypes))
for i, argType := range argTypes {
argsAsArray[i] = fmt.Sprintf("_param%v []%v", i, argType)
}
g.p("func (c *%v) GetAllCapturedArguments() (%v) {", ongoingVerificationStructName, strings.Join(argsAsArray, ", "))
g.p("func (c *%v%v) GetAllCapturedArguments() (%v) {", ongoingVerificationStructName, typeParamNames, strings.Join(argsAsArray, ", "))
if len(argTypes) > 0 {
g.p("params := pegomock.GetGenericMockFrom(c.mock).GetInvocationParams(c.methodInvocations)")
g.p("if len(params) > 0 {")
Expand Down
4 changes: 2 additions & 2 deletions mockgen/mockgen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package mockgen_test

import (
"github.com/petergtz/pegomock/v3/mockgen"
"github.com/petergtz/pegomock/v3/modelgen/loader"
"github.com/petergtz/pegomock/v3/modelgen/xtools_packages"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand All @@ -11,7 +11,7 @@ import (
var _ = Describe("Mockgen", func() {
Context("matcherSourceCodes", func() {
It("uses correct naming pattern with underscores for keys, and correct types etc. in source code", func() {
ast, e := loader.GenerateModel("github.com/petergtz/pegomock/v3/test_interface", "Display")
ast, e := xtools_packages.GenerateModel("github.com/petergtz/pegomock/v3/test_interface", "Display")
Expect(e).NotTo(HaveOccurred())
_, matcherSourceCodes := mockgen.GenerateOutput(ast, "irrelevant", "MockDisplay", "test_package", "")

Expand Down
8 changes: 6 additions & 2 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ func (pkg *Package) Imports() map[string]bool {

// Interface is a Go interface.
type Interface struct {
Name string
Methods []*Method
Name string
TypeParams []*Parameter
Methods []*Method
}

func (intf *Interface) Print(w io.Writer) {
Expand All @@ -58,6 +59,9 @@ func (intf *Interface) Print(w io.Writer) {
}

func (intf *Interface) addImports(im map[string]bool) {
for _, tp := range intf.TypeParams {
tp.Type.addImports(im)
}
for _, m := range intf.Methods {
m.addImports(im)
}
Expand Down
3 changes: 3 additions & 0 deletions modelgen/gomock/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"go/parser"
"go/token"
"log"
"os"
"path"
"strconv"
"strings"
Expand All @@ -38,6 +39,8 @@ var (
// TODO: simplify error reporting

func ParseFile(source string) (*model.Package, error) {
fmt.Fprintln(os.Stderr, "WARNING: The gomock package is deprecated and will be removed in a future version.")

fs := token.NewFileSet()
file, err := parser.ParseFile(fs, source, nil, 0)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions modelgen/modelgen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (

"github.com/petergtz/pegomock/v3/model"
"github.com/petergtz/pegomock/v3/modelgen/gomock"
"github.com/petergtz/pegomock/v3/modelgen/loader"
"github.com/petergtz/pegomock/v3/modelgen/xtools_packages"

"github.com/onsi/ginkgo/v2"
. "github.com/onsi/ginkgo/v2"
Expand All @@ -40,13 +40,13 @@ func (a alphabetically) Len() int { return len(a) }
func (a alphabetically) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a alphabetically) Less(i, j int) bool { return a[i].Name < a[j].Name }

var _ = Describe("modelgen/loader", func() {
var _ = Describe("xtools_packages", func() {
It("generates an equivalent model as gomock/reflect does", func() {
pkgFromReflect, e := gomock.Reflect("github.com/petergtz/pegomock/v3/test_interface", []string{"Display"})
Expect(e).NotTo(HaveOccurred())
sort.Sort(alphabetically(pkgFromReflect.Interfaces[0].Methods))

pkgFromLoader, e := loader.GenerateModel("github.com/petergtz/pegomock/v3/test_interface", "Display")
pkgFromLoader, e := xtools_packages.GenerateModel("github.com/petergtz/pegomock/v3/test_interface", "Display")
Expect(e).NotTo(HaveOccurred())
sort.Sort(alphabetically(pkgFromLoader.Interfaces[0].Methods))

Expand All @@ -60,7 +60,7 @@ var _ = Describe("modelgen/loader", func() {
})

It("generates a model with the basic properties", func() {
pkg, e := loader.GenerateModel("github.com/petergtz/pegomock/v3/modelgen/test_data/default_test_interface", "Display")
pkg, e := xtools_packages.GenerateModel("github.com/petergtz/pegomock/v3/modelgen/test_data/default_test_interface", "Display")
Expect(e).NotTo(HaveOccurred())

Expect(pkg.Name).To(Equal("test_interface"))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package loader
package xtools_loader

import (
"errors"
Expand All @@ -11,6 +11,8 @@ import (
)

func GenerateModel(importPath string, interfaceName string) (*model.Package, error) {
panic("DEPRECATED: Use GenerateModelViaPackages instead")

var conf loader.Config
conf.Import(importPath)
program, e := conf.Load()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package loader_test
package xtools_loader_test

import (
"testing"
Expand All @@ -9,5 +9,5 @@ import (

func TestLoader(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Loader Suite")
RunSpecs(t, "xtools_loader_test Suite")
}
Loading

0 comments on commit 223960f

Please sign in to comment.