Skip to content

Commit

Permalink
Hacks around elided parameter types
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan Moran committed Mar 7, 2019
1 parent 45d7d06 commit 74621c9
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 30 deletions.
6 changes: 5 additions & 1 deletion gen/argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ type Argument struct {
Type string
}

func NewArgument(field *ast.Field, method string, fieldType types.Type, name string) Argument {
func NewArgument(fieldName string, field *ast.Field, method string, fieldType types.Type, name string) Argument {
if len(field.Names) > 0 {
name = field.Names[0].Name
}

if fieldName != "" {
name = fieldName
}

return Argument{
Method: method,
Name: strcase.ToLowerCamel(name),
Expand Down
2 changes: 1 addition & 1 deletion gen/argument_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ var _ = Describe("Argument", func() {
fakeType := &FakeType{}
fakeType.StringCall.Returns.String = "SomeType"

Expect(gen.NewArgument(&ast.Field{
Expect(gen.NewArgument("", &ast.Field{
Names: []*ast.Ident{ast.NewIdent("SomeName")},
Type: ast.NewIdent("SomeType"),
}, "SomeMethod", fakeType, "fallbackName")).To(Equal(gen.Argument{
Expand Down
61 changes: 33 additions & 28 deletions gen/parse_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,37 +54,11 @@ func parse(name string, typesInfo map[ast.Expr]types.TypeAndValue, files ...*ast
if funcType, ok := field.Type.(*ast.FuncType); ok {
methodName := field.Names[0].Name

var params []Argument
paramTypeCounts := map[string]int{}
for _, field := range funcType.Params.List {
fallbackName := types.ExprString(field.Type)
paramTypeCounts[fallbackName] = paramTypeCounts[fallbackName] + 1

if paramTypeCounts[fallbackName] > 1 {
fallbackName = fmt.Sprintf("%s%d", fallbackName, paramTypeCounts[fallbackName])
}

params = append(params, NewArgument(field, methodName, typesInfo[field.Type].Type, fallbackName))
}

var results []Argument
resultTypeCounts := map[string]int{}
for _, field := range funcType.Results.List {
fallbackName := types.ExprString(field.Type)
resultTypeCounts[fallbackName] = resultTypeCounts[fallbackName] + 1

if resultTypeCounts[fallbackName] > 1 {
fallbackName = fmt.Sprintf("%s%d", fallbackName, resultTypeCounts[fallbackName])
}

results = append(results, NewArgument(field, methodName, typesInfo[field.Type].Type, fallbackName))
}

methods = append(methods, Method{
Name: methodName,
Receiver: typeSpec.Name.Name,
Params: params,
Results: results,
Params: parseArguments(methodName, typesInfo, funcType.Params.List),
Results: parseArguments(methodName, typesInfo, funcType.Results.List),
})
}
}
Expand All @@ -103,3 +77,34 @@ func parse(name string, typesInfo map[ast.Expr]types.TypeAndValue, files ...*ast

return Fake{}, false, nil
}

func parseArguments(methodName string, typesInfo map[ast.Expr]types.TypeAndValue, fields []*ast.Field) []Argument {
argTypeCounts := map[string]int{}

var args []Argument
for _, field := range fields {
fallbackName := types.ExprString(field.Type)

if len(field.Names) > 1 {
for _, fieldName := range field.Names {
argTypeCounts[fallbackName] = argTypeCounts[fallbackName] + 1

if argTypeCounts[fallbackName] > 1 {
fallbackName = fmt.Sprintf("%s%d", fallbackName, argTypeCounts[fallbackName])
}

args = append(args, NewArgument(types.ExprString(fieldName), field, methodName, typesInfo[field.Type].Type, fallbackName))
}
} else {
argTypeCounts[fallbackName] = argTypeCounts[fallbackName] + 1

if argTypeCounts[fallbackName] > 1 {
fallbackName = fmt.Sprintf("%s%d", fallbackName, argTypeCounts[fallbackName])
}

args = append(args, NewArgument("", field, methodName, typesInfo[field.Type].Type, fallbackName))
}
}

return args
}
49 changes: 49 additions & 0 deletions gen/parse_file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,55 @@ type SomeInterface interface{
})
})

Context("when the types are elided", func() {
It("parses the given file, returning a fake matching the given named interface", func() {
source := strings.NewReader(`package main
import "io"
type SomeInterface interface{
SomeMethod(someParam1, someParam2 string) (int, io.Reader)
}
`)

fake, err := gen.ParseFile("some-file.go", source, "SomeInterface")
Expect(err).NotTo(HaveOccurred())
Expect(fake).To(Equal(gen.Fake{
Name: "SomeInterface",
Methods: []gen.Method{
{
Name: "SomeMethod",
Receiver: "SomeInterface",
Params: []gen.Argument{
{
Method: "SomeMethod",
Name: "someParam1",
Type: "string",
},
{
Method: "SomeMethod",
Name: "someParam2",
Type: "string",
},
},
Results: []gen.Argument{
{
Method: "SomeMethod",
Name: "int",
Type: "int",
},
{
Method: "SomeMethod",
Name: "ioReader",
Type: "io.Reader",
},
},
},
},
}))
})
})

Context("failure cases", func() {
Context("when the source file cannot be parsed", func() {
It("returns an error", func() {
Expand Down

0 comments on commit 74621c9

Please sign in to comment.