From dd7d6f75625488f72ac282df7053a70c9cc54e83 Mon Sep 17 00:00:00 2001 From: Ryan Moran Date: Wed, 4 Sep 2024 13:51:05 -0700 Subject: [PATCH] Adds initial generics support --- .../fixtures/fakes/generic_interface.go | 32 +++++++++ acceptance/fixtures/interfaces.go | 9 +++ acceptance/generate_test.go | 1 + parsing/argument.go | 13 +++- parsing/interface.go | 7 ++ parsing/interface_test.go | 33 +++++++++ parsing/signature.go | 4 +- parsing/type_param.go | 8 +++ rendering/context.go | 21 +++--- rendering/named_type.go | 70 ++++++++++++++----- rendering/type.go | 25 ++++--- rendering/util.go | 12 ++-- 12 files changed, 194 insertions(+), 41 deletions(-) create mode 100644 acceptance/fixtures/fakes/generic_interface.go create mode 100644 parsing/type_param.go diff --git a/acceptance/fixtures/fakes/generic_interface.go b/acceptance/fixtures/fakes/generic_interface.go new file mode 100644 index 0000000..716a667 --- /dev/null +++ b/acceptance/fixtures/fakes/generic_interface.go @@ -0,0 +1,32 @@ +package fakes + +import ( + "sync" + + "github.com/ryanmoran/faux/acceptance/fixtures" +) + +type GenericInterface[T comparable, S comparable] struct { + SomeMethodCall struct { + mutex sync.Mutex + CallCount int + Receives struct { + MapTS map[T]S + } + Returns struct { + ResultIntError fixtures.Result[int, error] + } + Stub func(map[T]S) fixtures.Result[int, error] + } +} + +func (f *GenericInterface[T, S]) SomeMethod(param1 map[T]S) fixtures.Result[int, error] { + f.SomeMethodCall.mutex.Lock() + defer f.SomeMethodCall.mutex.Unlock() + f.SomeMethodCall.CallCount++ + f.SomeMethodCall.Receives.MapTS = param1 + if f.SomeMethodCall.Stub != nil { + return f.SomeMethodCall.Stub(param1) + } + return f.SomeMethodCall.Returns.ResultIntError +} diff --git a/acceptance/fixtures/interfaces.go b/acceptance/fixtures/interfaces.go index f091ce9..e183a78 100644 --- a/acceptance/fixtures/interfaces.go +++ b/acceptance/fixtures/interfaces.go @@ -32,6 +32,15 @@ type NamedInterface interface { SomeMethod(someParam *bytes.Buffer) (someResult io.Reader) } +type Result[T, E any] struct { + Value T + Error E +} + +type GenericInterface[T, S comparable] interface { + SomeMethod(map[T]S) Result[int, error] +} + type BurntSushiParser struct { Key toml.Key } diff --git a/acceptance/generate_test.go b/acceptance/generate_test.go index 0864747..49f1b0c 100644 --- a/acceptance/generate_test.go +++ b/acceptance/generate_test.go @@ -61,6 +61,7 @@ var _ = Describe("faux", func() { Entry("variadic", "variadic_interface.go", "--file", "./fixtures/interfaces.go", "--interface", "VariadicInterface"), Entry("functions", "function_interface.go", "--file", "./fixtures/interfaces.go", "--interface", "FunctionInterface"), Entry("name", "named_interface.go", "--file", "./fixtures/interfaces.go", "--interface", "NamedInterface", "--name", "SomeNamedInterface"), + Entry("generic", "generic_interface.go", "--file", "./fixtures/interfaces.go", "--interface", "GenericInterface"), ) Context("when the source file is provided via an environment variable", func() { diff --git a/parsing/argument.go b/parsing/argument.go index 7e30ac8..96eab8d 100644 --- a/parsing/argument.go +++ b/parsing/argument.go @@ -7,13 +7,23 @@ import ( type Argument struct { Name string Type types.Type + TypeArgs []types.Type Variadic bool Package string } func NewArgument(v *types.Var, variadic bool) Argument { - var pkg string + var ( + pkg string + typeArgs []types.Type + ) + if t, ok := v.Type().(*types.Named); ok { + targs := t.TypeArgs() + for i := 0; i < targs.Len(); i++ { + typeArgs = append(typeArgs, targs.At(i)) + } + if t.Obj().Pkg() != nil { pkg = t.Obj().Pkg().Path() } @@ -22,6 +32,7 @@ func NewArgument(v *types.Var, variadic bool) Argument { return Argument{ Name: v.Name(), Type: v.Type(), + TypeArgs: typeArgs, Variadic: variadic, Package: pkg, } diff --git a/parsing/interface.go b/parsing/interface.go index 9f1e05c..5a43cc8 100644 --- a/parsing/interface.go +++ b/parsing/interface.go @@ -7,12 +7,18 @@ import ( type Interface struct { Name string + TypeArgs []*types.TypeParam Signatures []Signature } func NewInterface(n *types.Named) (Interface, error) { var signatures []Signature + var targs []*types.TypeParam + for i := 0; i < n.TypeParams().Len(); i++ { + targs = append(targs, n.TypeParams().At(i)) + } + underlying, ok := n.Underlying().(*types.Interface) if !ok { return Interface{}, fmt.Errorf("failed to load underlying type: %s is not an interface", n.Underlying()) @@ -24,6 +30,7 @@ func NewInterface(n *types.Named) (Interface, error) { return Interface{ Name: n.Obj().Name(), + TypeArgs: targs, Signatures: signatures, }, nil } diff --git a/parsing/interface_test.go b/parsing/interface_test.go index 673523d..aab8666 100644 --- a/parsing/interface_test.go +++ b/parsing/interface_test.go @@ -57,6 +57,39 @@ var _ = Describe("Interface", func() { }) }) + Context("when the interface has type params", func() { + var typeParam *types.TypeParam + + BeforeEach(func() { + signature := types.NewSignature(nil, nil, nil, false) + methods := []*types.Func{ + types.NewFunc(0, pkg, "SomeMethod", signature), + } + + underlying = types.NewInterfaceType(methods, nil).Complete() + namedType = types.NewNamed(typeName, underlying, nil) + + typeName := types.NewTypeName(0, pkg, "T", nil) + constraint := types.NewNamed(types.NewTypeName(0, nil, "any", nil), types.NewInterface(nil, nil), nil) + typeParam = types.NewTypeParam(typeName, constraint) + namedType.SetTypeParams([]*types.TypeParam{typeParam}) + }) + + It("includes those methods in the parsed interface", func() { + iface, err := parsing.NewInterface(namedType) + Expect(err).NotTo(HaveOccurred()) + Expect(iface).To(Equal(parsing.Interface{ + Name: "SomeType", + Signatures: []parsing.Signature{ + { + Name: "SomeMethod", + }, + }, + TypeArgs: []*types.TypeParam{typeParam}, + })) + }) + }) + Context("when the underlying type is not interface", func() { BeforeEach(func() { intType := types.Universe.Lookup("int").Type() diff --git a/parsing/signature.go b/parsing/signature.go index a0064e5..4eaad2b 100644 --- a/parsing/signature.go +++ b/parsing/signature.go @@ -1,6 +1,8 @@ package parsing -import "go/types" +import ( + "go/types" +) type Signature struct { Name string diff --git a/parsing/type_param.go b/parsing/type_param.go new file mode 100644 index 0000000..8103d6e --- /dev/null +++ b/parsing/type_param.go @@ -0,0 +1,8 @@ +package parsing + +import "go/types" + +type TypeParam struct { + Name string + Constraint types.Type +} diff --git a/rendering/context.go b/rendering/context.go index 6e2f038..e023756 100644 --- a/rendering/context.go +++ b/rendering/context.go @@ -35,7 +35,12 @@ func (c *Context) BuildFakeType(iface parsing.Interface) NamedType { calls = append(calls, c.BuildCallStruct(signature)) } - return NewNamedType(TitleString(iface.Name), NewStruct(calls)) + var targs []Type + for _, targ := range iface.TypeArgs { + targs = append(targs, NewType(targ, nil)) + } + + return NewNamedType(TitleString(iface.Name), NewStruct(calls), targs) } func (c *Context) BuildCallStruct(signature parsing.Signature) Field { @@ -58,7 +63,7 @@ func (c *Context) BuildCallStruct(signature parsing.Signature) Field { } func (c *Context) BuildMutex() Field { - return NewField("mutex", NewNamedType("sync.Mutex", NewStruct(nil))) + return NewField("mutex", NewNamedType("sync.Mutex", NewStruct(nil), nil)) } func (c *Context) BuildCallCount() Field { @@ -74,7 +79,7 @@ func (c *Context) BuildReceives(args []parsing.Argument) Field { } name = TitleString(name) - field := NewField(name, NewType(arg.Type)) + field := NewField(name, NewType(arg.Type, arg.TypeArgs)) fields = append(fields, field) } @@ -90,7 +95,7 @@ func (c *Context) BuildReturns(args []parsing.Argument) Field { } name = TitleString(name) - field := NewField(name, NewType(arg.Type)) + field := NewField(name, NewType(arg.Type, arg.TypeArgs)) fields = append(fields, field) } @@ -121,7 +126,7 @@ func (c *Context) BuildParams(args []parsing.Argument, named bool) []Param { name = ParamName(i) } - params = append(params, NewParam(name, NewType(arg.Type), arg.Variadic)) + params = append(params, NewParam(name, NewType(arg.Type, arg.TypeArgs), arg.Variadic)) } return params @@ -130,7 +135,7 @@ func (c *Context) BuildParams(args []parsing.Argument, named bool) []Param { func (c *Context) BuildResults(args []parsing.Argument) []Result { var results []Result for _, arg := range args { - results = append(results, NewResult(NewType(arg.Type))) + results = append(results, NewResult(NewType(arg.Type, arg.TypeArgs))) } return results @@ -143,7 +148,7 @@ func (c *Context) BuildBody(receiver Receiver, signature parsing.Signature) []St c.BuildIncrementStatement(receiver, signature.Name), } - for i, _ := range signature.Params { + for i := range signature.Params { statements = append(statements, c.BuildAssignStatement(receiver, signature.Name, i, signature.Params)) } @@ -195,7 +200,7 @@ func (c *Context) BuildAssignStatement(receiver Receiver, name string, index int paramField := receivesField.Type.(Struct).FieldWithName(argName) selector := NewSelector(receiver, callField, receivesField, paramField) paramName := ParamName(index) - param := NewParam(paramName, NewType(arg.Type), arg.Variadic) + param := NewParam(paramName, NewType(arg.Type, arg.TypeArgs), arg.Variadic) return NewAssignStatement(selector, param) } diff --git a/rendering/named_type.go b/rendering/named_type.go index 1b2e404..75c1ccc 100644 --- a/rendering/named_type.go +++ b/rendering/named_type.go @@ -6,38 +6,72 @@ import ( ) type NamedType struct { - Name string - Type Type + Name string + Type Type + TypeArgs []Type } -func NewNamedType(name string, t Type) NamedType { +func NewNamedType(name string, t Type, targTypes []Type) NamedType { return NamedType{ - Name: name, - Type: t, + Name: name, + Type: t, + TypeArgs: targTypes, } } -func NewDefinedType(name string) NamedType { - return NamedType{ - Name: name, - Type: Interface{}, - } +func NewDefinedType(name string, targTypes []Type) NamedType { + return NewNamedType(name, Interface{}, targTypes) } func (nt NamedType) Expr() ast.Expr { - return ast.NewIdent(nt.Name) + switch len(nt.TypeArgs) { + case 0: + return ast.NewIdent(nt.Name) + + case 1: + return &ast.IndexExpr{ + X: ast.NewIdent(nt.Name), + Index: nt.TypeArgs[0].Expr(), + } + + default: + var indices []ast.Expr + for _, typeArg := range nt.TypeArgs { + indices = append(indices, typeArg.Expr()) + } + + return &ast.IndexListExpr{ + X: ast.NewIdent(nt.Name), + Indices: indices, + } + } } func (nt NamedType) isType() {} func (nt NamedType) Decl() ast.Decl { + spec := &ast.TypeSpec{ + Name: ast.NewIdent(nt.Name), + Type: nt.Type.Expr(), + } + + if len(nt.TypeArgs) > 0 { + var fields []*ast.Field + for _, targ := range nt.TypeArgs { + ntarg := targ.(NamedType) + fields = append(fields, &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(ntarg.Name)}, + Type: ntarg.Type.Expr(), + }) + } + + spec.TypeParams = &ast.FieldList{ + List: fields, + } + } + return &ast.GenDecl{ - Tok: token.TYPE, - Specs: []ast.Spec{ - &ast.TypeSpec{ - Name: ast.NewIdent(nt.Name), - Type: nt.Type.Expr(), - }, - }, + Tok: token.TYPE, + Specs: []ast.Spec{spec}, } } diff --git a/rendering/type.go b/rendering/type.go index 8d5bdfa..5873cd5 100644 --- a/rendering/type.go +++ b/rendering/type.go @@ -10,10 +10,10 @@ type Type interface { isType() } -func NewType(t types.Type) Type { +func NewType(t types.Type, targs []types.Type) Type { switch s := t.(type) { case *types.Slice: - return NewSlice(NewType(s.Elem())) + return NewSlice(NewType(s.Elem(), nil)) case *types.Basic: return NewBasicType(s) @@ -26,17 +26,24 @@ func NewType(t types.Type) Type { if pkg != nil { name = fmt.Sprintf("%s.%s", pkg.Name(), obj.Name()) } + var targTypes []Type + for _, targ := range targs { + targTypes = append(targTypes, NewType(targ, nil)) + } + + return NewDefinedType(name, targTypes) - return NewDefinedType(name) + case *types.TypeParam: + return NewNamedType(s.String(), NewType(s.Constraint(), nil), nil) case *types.Interface: return Interface{} case *types.Pointer: - return NewPointer(NewType(s.Elem())) + return NewPointer(NewType(s.Elem(), nil)) case *types.Map: - return NewMap(NewType(s.Key()), NewType(s.Elem())) + return NewMap(NewType(s.Key(), nil), NewType(s.Elem(), nil)) case *types.Chan: var send, recv bool @@ -50,13 +57,13 @@ func NewType(t types.Type) Type { recv = true } - return NewChan(NewType(s.Elem()), send, recv) + return NewChan(NewType(s.Elem(), nil), send, recv) case *types.Struct: var fields []Field for i := 0; i < s.NumFields(); i++ { field := s.Field(i) - fields = append(fields, NewField(field.Name(), NewType(field.Type()))) + fields = append(fields, NewField(field.Name(), NewType(field.Type(), nil))) } return NewStruct(fields) @@ -65,13 +72,13 @@ func NewType(t types.Type) Type { var params []Param for i := 0; i < s.Params().Len(); i++ { param := s.Params().At(i) - params = append(params, NewParam("", NewType(param.Type()), false)) + params = append(params, NewParam("", NewType(param.Type(), nil), false)) } var results []Result for i := 0; i < s.Results().Len(); i++ { result := s.Results().At(i) - results = append(results, NewResult(NewType(result.Type()))) + results = append(results, NewResult(NewType(result.Type(), nil))) } return NewFunc(s.String(), Receiver{}, params, results, nil) diff --git a/rendering/util.go b/rendering/util.go index 37dea2a..48a9b93 100644 --- a/rendering/util.go +++ b/rendering/util.go @@ -39,8 +39,12 @@ func TypeName(t Type) string { return s.Underlying.String() case NamedType: + var targs []string + for _, targ := range s.TypeArgs { + targs = append(targs, TitleString(TypeName(targ))) + } parts := strings.Split(s.Name, ".") - return parts[len(parts)-1] + return strings.Join(append([]string{parts[len(parts)-1]}, targs...), "") case Pointer: return TypeName(s.Elem) @@ -59,13 +63,13 @@ func FieldTypeName(args []parsing.Argument, index int) string { nameCounts := map[string]int{} counter := map[string]int{} for _, arg := range args { - name := TypeName(NewType(arg.Type)) + name := TypeName(NewType(arg.Type, arg.TypeArgs)) nameCounts[name]++ } var indexedCounts []int for _, arg := range args { - name := TypeName(NewType(arg.Type)) + name := TypeName(NewType(arg.Type, arg.TypeArgs)) if nameCounts[name] > 1 { counter[name]++ indexedCounts = append(indexedCounts, counter[name]) @@ -74,7 +78,7 @@ func FieldTypeName(args []parsing.Argument, index int) string { } } - typeName := TypeName(NewType(args[index].Type)) + typeName := TypeName(NewType(args[index].Type, args[index].TypeArgs)) if indexedCounts[index] > 0 { typeName = fmt.Sprintf("%s_%d", typeName, indexedCounts[index]) }