diff --git a/gen/argument.go b/gen/argument.go index f7ef0f3..bc93dc1 100644 --- a/gen/argument.go +++ b/gen/argument.go @@ -15,11 +15,15 @@ type Argument struct { Type string } -func NewArgument(field *ast.Field, method string, fieldType types.Type, name string) Argument { +func NewArgument(fieldName string, field *ast.Field, method string, fieldType types.Type, name string) Argument { if len(field.Names) > 0 { name = field.Names[0].Name } + if fieldName != "" { + name = fieldName + } + return Argument{ Method: method, Name: strcase.ToLowerCamel(name), diff --git a/gen/argument_test.go b/gen/argument_test.go index 79906f3..ba760a5 100644 --- a/gen/argument_test.go +++ b/gen/argument_test.go @@ -38,7 +38,7 @@ var _ = Describe("Argument", func() { fakeType := &FakeType{} fakeType.StringCall.Returns.String = "SomeType" - Expect(gen.NewArgument(&ast.Field{ + Expect(gen.NewArgument("", &ast.Field{ Names: []*ast.Ident{ast.NewIdent("SomeName")}, Type: ast.NewIdent("SomeType"), }, "SomeMethod", fakeType, "fallbackName")).To(Equal(gen.Argument{ diff --git a/gen/parse_file.go b/gen/parse_file.go index 3973352..3bc1db0 100644 --- a/gen/parse_file.go +++ b/gen/parse_file.go @@ -54,37 +54,11 @@ func parse(name string, typesInfo map[ast.Expr]types.TypeAndValue, files ...*ast if funcType, ok := field.Type.(*ast.FuncType); ok { methodName := field.Names[0].Name - var params []Argument - paramTypeCounts := map[string]int{} - for _, field := range funcType.Params.List { - fallbackName := types.ExprString(field.Type) - paramTypeCounts[fallbackName] = paramTypeCounts[fallbackName] + 1 - - if paramTypeCounts[fallbackName] > 1 { - fallbackName = fmt.Sprintf("%s%d", fallbackName, paramTypeCounts[fallbackName]) - } - - params = append(params, NewArgument(field, methodName, typesInfo[field.Type].Type, fallbackName)) - } - - var results []Argument - resultTypeCounts := map[string]int{} - for _, field := range funcType.Results.List { - fallbackName := types.ExprString(field.Type) - resultTypeCounts[fallbackName] = resultTypeCounts[fallbackName] + 1 - - if resultTypeCounts[fallbackName] > 1 { - fallbackName = fmt.Sprintf("%s%d", fallbackName, resultTypeCounts[fallbackName]) - } - - results = append(results, NewArgument(field, methodName, typesInfo[field.Type].Type, fallbackName)) - } - methods = append(methods, Method{ Name: methodName, Receiver: typeSpec.Name.Name, - Params: params, - Results: results, + Params: parseArguments(methodName, typesInfo, funcType.Params.List), + Results: parseArguments(methodName, typesInfo, funcType.Results.List), }) } } @@ -103,3 +77,34 @@ func parse(name string, typesInfo map[ast.Expr]types.TypeAndValue, files ...*ast return Fake{}, false, nil } + +func parseArguments(methodName string, typesInfo map[ast.Expr]types.TypeAndValue, fields []*ast.Field) []Argument { + argTypeCounts := map[string]int{} + + var args []Argument + for _, field := range fields { + fallbackName := types.ExprString(field.Type) + + if len(field.Names) > 1 { + for _, fieldName := range field.Names { + argTypeCounts[fallbackName] = argTypeCounts[fallbackName] + 1 + + if argTypeCounts[fallbackName] > 1 { + fallbackName = fmt.Sprintf("%s%d", fallbackName, argTypeCounts[fallbackName]) + } + + args = append(args, NewArgument(types.ExprString(fieldName), field, methodName, typesInfo[field.Type].Type, fallbackName)) + } + } else { + argTypeCounts[fallbackName] = argTypeCounts[fallbackName] + 1 + + if argTypeCounts[fallbackName] > 1 { + fallbackName = fmt.Sprintf("%s%d", fallbackName, argTypeCounts[fallbackName]) + } + + args = append(args, NewArgument("", field, methodName, typesInfo[field.Type].Type, fallbackName)) + } + } + + return args +} diff --git a/gen/parse_file_test.go b/gen/parse_file_test.go index b6f42da..f19db3f 100644 --- a/gen/parse_file_test.go +++ b/gen/parse_file_test.go @@ -112,6 +112,55 @@ type SomeInterface interface{ }) }) + Context("when the types are elided", func() { + It("parses the given file, returning a fake matching the given named interface", func() { + source := strings.NewReader(`package main + +import "io" + +type SomeInterface interface{ + SomeMethod(someParam1, someParam2 string) (int, io.Reader) +} +`) + + fake, err := gen.ParseFile("some-file.go", source, "SomeInterface") + Expect(err).NotTo(HaveOccurred()) + Expect(fake).To(Equal(gen.Fake{ + Name: "SomeInterface", + Methods: []gen.Method{ + { + Name: "SomeMethod", + Receiver: "SomeInterface", + Params: []gen.Argument{ + { + Method: "SomeMethod", + Name: "someParam1", + Type: "string", + }, + { + Method: "SomeMethod", + Name: "someParam2", + Type: "string", + }, + }, + Results: []gen.Argument{ + { + Method: "SomeMethod", + Name: "int", + Type: "int", + }, + { + Method: "SomeMethod", + Name: "ioReader", + Type: "io.Reader", + }, + }, + }, + }, + })) + }) + }) + Context("failure cases", func() { Context("when the source file cannot be parsed", func() { It("returns an error", func() {