Skip to content

Commit

Permalink
Merge pull request #36 from csueiras/fix-file-writer
Browse files Browse the repository at this point in the history
FIX issue #34 and #35
  • Loading branch information
csueiras authored Feb 23, 2021
2 parents f1d5023 + 96a4b7a commit c3f98c2
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 40 deletions.
7 changes: 1 addition & 6 deletions internal/generator/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,7 @@ func (e *Executor) Execute(settings *Parameters) (*generator.Generated, error) {
return nil, fmt.Errorf("multiple types with same name discovered with name %s", typName)
}
results[typName] = res.InterfaceType

cfg = append(cfg, &generator.FileConfig{
SrcTypeName: typName,
OutTypeName: typName,
InterfaceType: res.InterfaceType,
})
cfg = append(cfg, generator.NewFileConfig(typName, typName, res.InterfaceType))
}
}

Expand Down
49 changes: 29 additions & 20 deletions internal/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,29 @@ var fileHeader = "Code generated by reinforcer, DO NOT EDIT."

// FileConfig holds the code generation configuration for a specific type
type FileConfig struct {
// SrcTypeName is the source type that we want to generate code for
SrcTypeName string
// OutTypeName is the desired output type name
OutTypeName string
// InterfaceType holds the type information for SrcTypeName
InterfaceType *types.Interface
// srcTypeName is the source type that we want to generate code for
srcTypeName string
// outTypeName is the desired output type name
outTypeName string
// interfaceType holds the type information for SrcTypeName
interfaceType *types.Interface
}

// NewFileConfig creates a new instance of the FileConfig which holds code generation configuration
func NewFileConfig(srcTypeName string, outTypeName string, interfaceType *types.Interface) *FileConfig {
return &FileConfig{
srcTypeName: strings.Title(srcTypeName),
outTypeName: strings.Title(outTypeName),
interfaceType: interfaceType,
}
}

func (f *FileConfig) targetName() string {
return "target" + f.SrcTypeName
return "target" + f.srcTypeName
}

func (f *FileConfig) receiverName() string {
return strings.ToLower(f.OutTypeName[0:1])
return strings.ToLower(f.outTypeName[0:1])
}

// Config holds the code generation configuration for all of desired types
Expand Down Expand Up @@ -88,7 +97,7 @@ func Generate(cfg Config) (*Generated, error) {
var fileMethods []*fileMeta

for _, fileConfig := range cfg.Files {
methods, err := parseMethods(fileConfig.OutTypeName, fileConfig.InterfaceType)
methods, err := parseMethods(fileConfig.outTypeName, fileConfig.interfaceType)
if err != nil {
return nil, err
}
Expand All @@ -97,7 +106,7 @@ func Generate(cfg Config) (*Generated, error) {
return nil, err
}
gen.Files = append(gen.Files, &GeneratedFile{
TypeName: fileConfig.OutTypeName,
TypeName: fileConfig.outTypeName,
Contents: s,
})
fileMethods = append(fileMethods, &fileMeta{fileConfig: fileConfig, methods: methods})
Expand Down Expand Up @@ -128,17 +137,17 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig
))

// Declare the proxy implementation
f.Add(jen.Type().Id(fileCfg.OutTypeName).Struct(
f.Add(jen.Type().Id(fileCfg.outTypeName).Struct(
jen.Op("*").Id("base"),
jen.Id("delegate").Id(fileCfg.targetName()),
))

// Declare the ctor
f.Add(jen.Func().Id("New"+fileCfg.SrcTypeName).Params(
f.Add(jen.Func().Id("New"+fileCfg.outTypeName).Params(
jen.Id("delegate").Id(fileCfg.targetName()),
jen.Id("runnerFactory").Id("runnerFactory"),
jen.Id("options").Op("...").Id("Option"),
).Op("*").Id(fileCfg.OutTypeName).Block(
).Op("*").Id(fileCfg.outTypeName).Block(
// if delegate == nil
jen.If(jen.Id("delegate").Op("==").Nil().Block(
// panic("...")
Expand All @@ -150,7 +159,7 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig
jen.Panic(jen.Lit("provided nil runner factory")),
)),
// c:= &OutTypeName{...}
jen.Id("c").Op(":=").Add(jen.Op("&").Id(fileCfg.OutTypeName).Values(jen.Dict{
jen.Id("c").Op(":=").Add(jen.Op("&").Id(fileCfg.outTypeName).Values(jen.Dict{
// embed the base struct
jen.Id("base"): jen.Op("&").Id("base").Values(jen.Dict{
jen.Id("errorPredicate"): jen.Id("RetryAllErrors"),
Expand All @@ -168,7 +177,7 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig
// Declare all of our proxy methods
for _, mm := range methods {
if mm.ReturnsError {
r := retryable.NewRetryable(mm, fileCfg.OutTypeName, fileCfg.receiverName())
r := retryable.NewRetryable(mm, fileCfg.outTypeName, fileCfg.receiverName())
s, err := r.Statement()
if err != nil {
return "", err
Expand All @@ -177,9 +186,9 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig
} else {
var p statement
if ignoreNoReturnMethods {
p = passthrough.NewPassThrough(mm, fileCfg.OutTypeName, fileCfg.receiverName())
p = passthrough.NewPassThrough(mm, fileCfg.outTypeName, fileCfg.receiverName())
} else {
p = noret.NewNoReturn(mm, fileCfg.OutTypeName, fileCfg.receiverName())
p = noret.NewNoReturn(mm, fileCfg.outTypeName, fileCfg.receiverName())
}
s, err := p.Statement()
if err != nil {
Expand Down Expand Up @@ -244,9 +253,9 @@ func generateConstants(outPkg string, meta []*fileMeta) (string, error) {
constantAssign = append(constantAssign, jen.Id(m.Name).Op(":").Lit(m.Name).Op(","))
}

constObjName := fmt.Sprintf("%sMethods", fm.fileConfig.OutTypeName)
log.Debug().Msgf("Adding constants for type %s", fm.fileConfig.OutTypeName)
f.Add(jen.Comment(fmt.Sprintf("%s are the methods in %s", constObjName, fm.fileConfig.OutTypeName)))
constObjName := fmt.Sprintf("%sMethods", fm.fileConfig.outTypeName)
log.Debug().Msgf("Adding constants for type %s", fm.fileConfig.outTypeName)
f.Add(jen.Comment(fmt.Sprintf("%s are the methods in %s", constObjName, fm.fileConfig.outTypeName)))
f.Add(
jen.Var().Id(constObjName).Op("=").Struct(
fields...,
Expand Down
135 changes: 125 additions & 10 deletions internal/generator/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/packages/packagestest"
"strings"
"testing"
)

Expand Down Expand Up @@ -100,7 +101,7 @@ type GeneratedService struct {
delegate targetService
}
func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
if delegate == nil {
panic("provided nil delegate")
}
Expand Down Expand Up @@ -243,7 +244,7 @@ type GeneratedService struct {
delegate targetService
}
func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
if delegate == nil {
panic("provided nil delegate")
}
Expand Down Expand Up @@ -428,7 +429,7 @@ type GeneratedService struct {
delegate targetService
}
func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
if delegate == nil {
panic("provided nil delegate")
}
Expand Down Expand Up @@ -553,7 +554,7 @@ type GeneratedService struct {
delegate targetService
}
func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
if delegate == nil {
panic("provided nil delegate")
}
Expand Down Expand Up @@ -675,7 +676,7 @@ type GeneratedService struct {
delegate targetService
}
func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
if delegate == nil {
panic("provided nil delegate")
}
Expand Down Expand Up @@ -742,6 +743,121 @@ func (g *GeneratedService) SendReceiveDir(arg0 chan error) error {
}
return err
}
`,
},
},
},
},
{
name: "Unexported",
ignoreNoReturnMethods: true,
inputs: map[string]input{
"users_service.go": {
interfaceName: "service",
code: `package fake
type service interface {
SayHello(name string) error
}
`,
},
},
outCode: &generator.Generated{
Common: `// Code generated by reinforcer, DO NOT EDIT.
package resilient
import (
"context"
goresilience "github.com/slok/goresilience"
)
type base struct {
errorPredicate func(string, error) bool
runnerFactory runnerFactory
}
type runnerFactory interface {
GetRunner(name string) goresilience.Runner
}
var RetryAllErrors = func(_ string, _ error) bool {
return true
}
type Option func(*base)
func WithRetryableErrorPredicate(fn func(string, error) bool) Option {
return func(o *base) {
o.errorPredicate = fn
}
}
func (b *base) run(ctx context.Context, name string, fn func(ctx context.Context) error) error {
return b.runnerFactory.GetRunner(name).Run(ctx, fn)
}
`,
Constants: `// Code generated by reinforcer, DO NOT EDIT.
package resilient
// GeneratedServiceMethods are the methods in GeneratedService
var GeneratedServiceMethods = struct {
SayHello string
}{
SayHello: "SayHello",
}
`,
Files: []*generator.GeneratedFile{
{
TypeName: "GeneratedService",
Contents: `// Code generated by reinforcer, DO NOT EDIT.
package resilient
import "context"
type targetService interface {
SayHello(arg0 string) error
}
type GeneratedService struct {
*base
delegate targetService
}
func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService {
if delegate == nil {
panic("provided nil delegate")
}
if runnerFactory == nil {
panic("provided nil runner factory")
}
c := &GeneratedService{
base: &base{
errorPredicate: RetryAllErrors,
runnerFactory: runnerFactory,
},
delegate: delegate,
}
for _, o := range options {
o(c.base)
}
return c
}
func (g *GeneratedService) SayHello(arg0 string) error {
var nonRetryableErr error
err := g.run(context.Background(), GeneratedServiceMethods.SayHello, func(_ context.Context) error {
var err error
err = g.delegate.SayHello(arg0)
if g.errorPredicate(GeneratedServiceMethods.SayHello, err) {
return err
}
nonRetryableErr = err
return nil
})
if nonRetryableErr != nil {
return nonRetryableErr
}
return err
}
`,
},
},
Expand Down Expand Up @@ -803,11 +919,10 @@ func loadInterface(t *testing.T, filesCode map[string]input) []*generator.FileCo
for _, in := range filesCode {
svc, err := l.LoadOne(pkg, in.interfaceName, loader.PackageLoadMode)
require.NoError(t, err)
loadedTypes = append(loadedTypes, &generator.FileConfig{
SrcTypeName: in.interfaceName,
OutTypeName: fmt.Sprintf("Generated%s", in.interfaceName),
InterfaceType: svc.InterfaceType,
})
loadedTypes = append(loadedTypes, generator.NewFileConfig(in.interfaceName,
fmt.Sprintf("Generated%s", strings.Title(in.interfaceName)),
svc.InterfaceType,
))
}
return loadedTypes
}
2 changes: 1 addition & 1 deletion internal/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (l *Loader) loadExpr(path string, expr *regexp.Regexp, mode LoadMode) (*pac

func (l *Loader) load(path string, mode LoadMode) ([]*packages.Package, error) {
cfg := &packages.Config{
Mode: packages.NeedTypes | packages.NeedImports,
Mode: packages.NeedTypes | packages.NeedImports | packages.NeedSyntax,
}

var pkgs []*packages.Package
Expand Down
6 changes: 4 additions & 2 deletions internal/loader/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type HelloWorldService interface {
}
type unexportedService interface {
ShouldNotBeSeen()
HelloWorld()
}
type NotAnInterface struct {
Expand All @@ -71,11 +71,13 @@ type NotAnInterface struct {
results, err := l.LoadMatched("github.com/csueiras/fake", []string{".*Service"}, loader.PackageLoadMode)
require.NoError(t, err)
require.NotNil(t, results)
require.Equal(t, 2, len(results))
require.Equal(t, 3, len(results))
require.NotNil(t, results["UserService"])
require.Equal(t, "interface{GetUserID(ctx context.Context, userID string) (string, error)}", results["UserService"].InterfaceType.String())
require.NotNil(t, results["HelloWorldService"])
require.Equal(t, "interface{Hello(ctx context.Context, name string) error}", results["HelloWorldService"].InterfaceType.String())
require.NotNil(t, results["unexportedService"])
require.Equal(t, "interface{HelloWorld()}", results["unexportedService"].InterfaceType.String())
})

t.Run("Multiple RegEx Expressions", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion internal/writer/io/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (F *FSOutputProvider) GetOutputTarget(filename string) (io.WriteCloser, err

filename = path.Base(filename)
fullPath := path.Join(dir, filename)
f, err := os.OpenFile(fullPath, os.O_RDWR|os.O_CREATE, 0755)
f, err := os.Create(fullPath)
if err != nil {
return nil, fmt.Errorf("failed to open file %s; error=%w", fullPath, err)
}
Expand Down

0 comments on commit c3f98c2

Please sign in to comment.