diff --git a/internal/generator/executor/executor.go b/internal/generator/executor/executor.go index 1b99afe..126a68d 100644 --- a/internal/generator/executor/executor.go +++ b/internal/generator/executor/executor.go @@ -64,12 +64,7 @@ func (e *Executor) Execute(settings *Parameters) (*generator.Generated, error) { return nil, fmt.Errorf("multiple types with same name discovered with name %s", typName) } results[typName] = res.InterfaceType - - cfg = append(cfg, &generator.FileConfig{ - SrcTypeName: typName, - OutTypeName: typName, - InterfaceType: res.InterfaceType, - }) + cfg = append(cfg, generator.NewFileConfig(typName, typName, res.InterfaceType)) } } diff --git a/internal/generator/generator.go b/internal/generator/generator.go index a540559..fbc5eda 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -17,20 +17,29 @@ var fileHeader = "Code generated by reinforcer, DO NOT EDIT." // FileConfig holds the code generation configuration for a specific type type FileConfig struct { - // SrcTypeName is the source type that we want to generate code for - SrcTypeName string - // OutTypeName is the desired output type name - OutTypeName string - // InterfaceType holds the type information for SrcTypeName - InterfaceType *types.Interface + // srcTypeName is the source type that we want to generate code for + srcTypeName string + // outTypeName is the desired output type name + outTypeName string + // interfaceType holds the type information for SrcTypeName + interfaceType *types.Interface +} + +// NewFileConfig creates a new instance of the FileConfig which holds code generation configuration +func NewFileConfig(srcTypeName string, outTypeName string, interfaceType *types.Interface) *FileConfig { + return &FileConfig{ + srcTypeName: strings.Title(srcTypeName), + outTypeName: strings.Title(outTypeName), + interfaceType: interfaceType, + } } func (f *FileConfig) targetName() string { - return "target" + f.SrcTypeName + return "target" + f.srcTypeName } func (f *FileConfig) receiverName() string { - return strings.ToLower(f.OutTypeName[0:1]) + return strings.ToLower(f.outTypeName[0:1]) } // Config holds the code generation configuration for all of desired types @@ -88,7 +97,7 @@ func Generate(cfg Config) (*Generated, error) { var fileMethods []*fileMeta for _, fileConfig := range cfg.Files { - methods, err := parseMethods(fileConfig.OutTypeName, fileConfig.InterfaceType) + methods, err := parseMethods(fileConfig.outTypeName, fileConfig.interfaceType) if err != nil { return nil, err } @@ -97,7 +106,7 @@ func Generate(cfg Config) (*Generated, error) { return nil, err } gen.Files = append(gen.Files, &GeneratedFile{ - TypeName: fileConfig.OutTypeName, + TypeName: fileConfig.outTypeName, Contents: s, }) fileMethods = append(fileMethods, &fileMeta{fileConfig: fileConfig, methods: methods}) @@ -128,17 +137,17 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig )) // Declare the proxy implementation - f.Add(jen.Type().Id(fileCfg.OutTypeName).Struct( + f.Add(jen.Type().Id(fileCfg.outTypeName).Struct( jen.Op("*").Id("base"), jen.Id("delegate").Id(fileCfg.targetName()), )) // Declare the ctor - f.Add(jen.Func().Id("New"+fileCfg.SrcTypeName).Params( + f.Add(jen.Func().Id("New"+fileCfg.outTypeName).Params( jen.Id("delegate").Id(fileCfg.targetName()), jen.Id("runnerFactory").Id("runnerFactory"), jen.Id("options").Op("...").Id("Option"), - ).Op("*").Id(fileCfg.OutTypeName).Block( + ).Op("*").Id(fileCfg.outTypeName).Block( // if delegate == nil jen.If(jen.Id("delegate").Op("==").Nil().Block( // panic("...") @@ -150,7 +159,7 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig jen.Panic(jen.Lit("provided nil runner factory")), )), // c:= &OutTypeName{...} - jen.Id("c").Op(":=").Add(jen.Op("&").Id(fileCfg.OutTypeName).Values(jen.Dict{ + jen.Id("c").Op(":=").Add(jen.Op("&").Id(fileCfg.outTypeName).Values(jen.Dict{ // embed the base struct jen.Id("base"): jen.Op("&").Id("base").Values(jen.Dict{ jen.Id("errorPredicate"): jen.Id("RetryAllErrors"), @@ -168,7 +177,7 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig // Declare all of our proxy methods for _, mm := range methods { if mm.ReturnsError { - r := retryable.NewRetryable(mm, fileCfg.OutTypeName, fileCfg.receiverName()) + r := retryable.NewRetryable(mm, fileCfg.outTypeName, fileCfg.receiverName()) s, err := r.Statement() if err != nil { return "", err @@ -177,9 +186,9 @@ func generateFile(outPkg string, ignoreNoReturnMethods bool, fileCfg *FileConfig } else { var p statement if ignoreNoReturnMethods { - p = passthrough.NewPassThrough(mm, fileCfg.OutTypeName, fileCfg.receiverName()) + p = passthrough.NewPassThrough(mm, fileCfg.outTypeName, fileCfg.receiverName()) } else { - p = noret.NewNoReturn(mm, fileCfg.OutTypeName, fileCfg.receiverName()) + p = noret.NewNoReturn(mm, fileCfg.outTypeName, fileCfg.receiverName()) } s, err := p.Statement() if err != nil { @@ -244,9 +253,9 @@ func generateConstants(outPkg string, meta []*fileMeta) (string, error) { constantAssign = append(constantAssign, jen.Id(m.Name).Op(":").Lit(m.Name).Op(",")) } - constObjName := fmt.Sprintf("%sMethods", fm.fileConfig.OutTypeName) - log.Debug().Msgf("Adding constants for type %s", fm.fileConfig.OutTypeName) - f.Add(jen.Comment(fmt.Sprintf("%s are the methods in %s", constObjName, fm.fileConfig.OutTypeName))) + constObjName := fmt.Sprintf("%sMethods", fm.fileConfig.outTypeName) + log.Debug().Msgf("Adding constants for type %s", fm.fileConfig.outTypeName) + f.Add(jen.Comment(fmt.Sprintf("%s are the methods in %s", constObjName, fm.fileConfig.outTypeName))) f.Add( jen.Var().Id(constObjName).Op("=").Struct( fields..., diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index d2601e4..48a5073 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/packages/packagestest" + "strings" "testing" ) @@ -100,7 +101,7 @@ type GeneratedService struct { delegate targetService } -func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { +func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { if delegate == nil { panic("provided nil delegate") } @@ -243,7 +244,7 @@ type GeneratedService struct { delegate targetService } -func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { +func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { if delegate == nil { panic("provided nil delegate") } @@ -428,7 +429,7 @@ type GeneratedService struct { delegate targetService } -func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { +func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { if delegate == nil { panic("provided nil delegate") } @@ -553,7 +554,7 @@ type GeneratedService struct { delegate targetService } -func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { +func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { if delegate == nil { panic("provided nil delegate") } @@ -675,7 +676,7 @@ type GeneratedService struct { delegate targetService } -func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { +func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { if delegate == nil { panic("provided nil delegate") } @@ -742,6 +743,121 @@ func (g *GeneratedService) SendReceiveDir(arg0 chan error) error { } return err } +`, + }, + }, + }, + }, + { + name: "Unexported", + ignoreNoReturnMethods: true, + inputs: map[string]input{ + "users_service.go": { + interfaceName: "service", + code: `package fake + +type service interface { + SayHello(name string) error +} +`, + }, + }, + outCode: &generator.Generated{ + Common: `// Code generated by reinforcer, DO NOT EDIT. + +package resilient + +import ( + "context" + goresilience "github.com/slok/goresilience" +) + +type base struct { + errorPredicate func(string, error) bool + runnerFactory runnerFactory +} +type runnerFactory interface { + GetRunner(name string) goresilience.Runner +} + +var RetryAllErrors = func(_ string, _ error) bool { + return true +} + +type Option func(*base) + +func WithRetryableErrorPredicate(fn func(string, error) bool) Option { + return func(o *base) { + o.errorPredicate = fn + } +} +func (b *base) run(ctx context.Context, name string, fn func(ctx context.Context) error) error { + return b.runnerFactory.GetRunner(name).Run(ctx, fn) +} +`, + Constants: `// Code generated by reinforcer, DO NOT EDIT. + +package resilient + +// GeneratedServiceMethods are the methods in GeneratedService +var GeneratedServiceMethods = struct { + SayHello string +}{ + SayHello: "SayHello", +} +`, + Files: []*generator.GeneratedFile{ + { + TypeName: "GeneratedService", + Contents: `// Code generated by reinforcer, DO NOT EDIT. + +package resilient + +import "context" + +type targetService interface { + SayHello(arg0 string) error +} +type GeneratedService struct { + *base + delegate targetService +} + +func NewGeneratedService(delegate targetService, runnerFactory runnerFactory, options ...Option) *GeneratedService { + if delegate == nil { + panic("provided nil delegate") + } + if runnerFactory == nil { + panic("provided nil runner factory") + } + c := &GeneratedService{ + base: &base{ + errorPredicate: RetryAllErrors, + runnerFactory: runnerFactory, + }, + delegate: delegate, + } + for _, o := range options { + o(c.base) + } + return c +} +func (g *GeneratedService) SayHello(arg0 string) error { + var nonRetryableErr error + err := g.run(context.Background(), GeneratedServiceMethods.SayHello, func(_ context.Context) error { + var err error + err = g.delegate.SayHello(arg0) + if g.errorPredicate(GeneratedServiceMethods.SayHello, err) { + return err + } + nonRetryableErr = err + return nil + }) + if nonRetryableErr != nil { + return nonRetryableErr + } + return err +} `, }, }, @@ -803,11 +919,10 @@ func loadInterface(t *testing.T, filesCode map[string]input) []*generator.FileCo for _, in := range filesCode { svc, err := l.LoadOne(pkg, in.interfaceName, loader.PackageLoadMode) require.NoError(t, err) - loadedTypes = append(loadedTypes, &generator.FileConfig{ - SrcTypeName: in.interfaceName, - OutTypeName: fmt.Sprintf("Generated%s", in.interfaceName), - InterfaceType: svc.InterfaceType, - }) + loadedTypes = append(loadedTypes, generator.NewFileConfig(in.interfaceName, + fmt.Sprintf("Generated%s", strings.Title(in.interfaceName)), + svc.InterfaceType, + )) } return loadedTypes } diff --git a/internal/loader/loader.go b/internal/loader/loader.go index f176dbc..f39288f 100644 --- a/internal/loader/loader.go +++ b/internal/loader/loader.go @@ -137,7 +137,7 @@ func (l *Loader) loadExpr(path string, expr *regexp.Regexp, mode LoadMode) (*pac func (l *Loader) load(path string, mode LoadMode) ([]*packages.Package, error) { cfg := &packages.Config{ - Mode: packages.NeedTypes | packages.NeedImports, + Mode: packages.NeedTypes | packages.NeedImports | packages.NeedSyntax, } var pkgs []*packages.Package diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 4065a2f..b9d7dad 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -52,7 +52,7 @@ type HelloWorldService interface { } type unexportedService interface { - ShouldNotBeSeen() + HelloWorld() } type NotAnInterface struct { @@ -71,11 +71,13 @@ type NotAnInterface struct { results, err := l.LoadMatched("github.com/csueiras/fake", []string{".*Service"}, loader.PackageLoadMode) require.NoError(t, err) require.NotNil(t, results) - require.Equal(t, 2, len(results)) + require.Equal(t, 3, len(results)) require.NotNil(t, results["UserService"]) require.Equal(t, "interface{GetUserID(ctx context.Context, userID string) (string, error)}", results["UserService"].InterfaceType.String()) require.NotNil(t, results["HelloWorldService"]) require.Equal(t, "interface{Hello(ctx context.Context, name string) error}", results["HelloWorldService"].InterfaceType.String()) + require.NotNil(t, results["unexportedService"]) + require.Equal(t, "interface{HelloWorld()}", results["unexportedService"].InterfaceType.String()) }) t.Run("Multiple RegEx Expressions", func(t *testing.T) { diff --git a/internal/writer/io/io.go b/internal/writer/io/io.go index da8ce5f..7087086 100644 --- a/internal/writer/io/io.go +++ b/internal/writer/io/io.go @@ -41,7 +41,7 @@ func (F *FSOutputProvider) GetOutputTarget(filename string) (io.WriteCloser, err filename = path.Base(filename) fullPath := path.Join(dir, filename) - f, err := os.OpenFile(fullPath, os.O_RDWR|os.O_CREATE, 0755) + f, err := os.Create(fullPath) if err != nil { return nil, fmt.Errorf("failed to open file %s; error=%w", fullPath, err) }