From 5f14fd191d1124adea3de9123838d8bba3a8188b Mon Sep 17 00:00:00 2001 From: Christian Sueiras Date: Sun, 7 Mar 2021 18:44:00 -0500 Subject: [PATCH 1/5] Support extracting contract from a struct. Fix to expressions handling from targets. --- example/client/client.go | 15 +- .../client/reinforced/reinforcer_constants.go | 7 + example/client/reinforced/service.go | 50 +++++++ internal/generator/executor/executor.go | 10 +- internal/generator/executor/executor_test.go | 15 +- internal/generator/generator.go | 31 +--- internal/generator/generator_test.go | 2 +- internal/generator/method/method.go | 69 ++------- internal/generator/method/method_test.go | 11 +- internal/generator/noret/noret.go | 2 +- internal/generator/noret/noret_test.go | 16 +-- .../generator/passthrough/passthrough_test.go | 5 +- internal/generator/retryable/retryable.go | 7 +- .../generator/retryable/retryable_test.go | 31 ++-- internal/loader/loader.go | 136 +++++++++++++++--- internal/loader/loader_test.go | 112 +++++++++++---- internal/testpkg/teststruct.go | 20 +++ internal/types/types.go | 60 ++++++++ 18 files changed, 419 insertions(+), 180 deletions(-) create mode 100644 example/client/reinforced/service.go create mode 100644 internal/testpkg/teststruct.go create mode 100644 internal/types/types.go diff --git a/example/client/client.go b/example/client/client.go index b5126b8..b3479bb 100644 --- a/example/client/client.go +++ b/example/client/client.go @@ -1,4 +1,4 @@ -//go:generate reinforcer --debug --target=Client --target=SomeOtherClient --outputdir=./reinforced +//go:generate reinforcer --debug --target=Client --target=SomeOtherClient --target=Service --outputdir=./reinforced package client @@ -26,10 +26,21 @@ type SomeOtherClient interface { DoStuff() error SaveFile(myFile *File, osFile *os.File) error GetUser(ctx context.Context) (*sub.User, error) - MethodWithChannel(myChan <- chan bool) error + MethodWithChannel(myChan <-chan bool) error MethodWithWildcard(arg interface{}) } +// Service is an example of a struct defined contract that will be reversed engineered by reinforcer +type Service struct{} + +// GetData retrieves data it might randomly error out +func (s *Service) GetData() ([]byte, error) { + if rand.Int()%5 == 0 { + return nil, fmt.Errorf("random failure") + } + return []byte{0xB, 0xE, 0xE, 0xF}, nil +} + // FakeClient is a Client implementation that will randomly fail type FakeClient struct { } diff --git a/example/client/reinforced/reinforcer_constants.go b/example/client/reinforced/reinforcer_constants.go index f04b7f1..7b7077e 100755 --- a/example/client/reinforced/reinforcer_constants.go +++ b/example/client/reinforced/reinforcer_constants.go @@ -11,6 +11,13 @@ var ClientMethods = struct { SayHello: "SayHello", } +// ServiceMethods are the methods in Service +var ServiceMethods = struct { + GetData string +}{ + GetData: "GetData", +} + // SomeOtherClientMethods are the methods in SomeOtherClient var SomeOtherClientMethods = struct { DoStuff string diff --git a/example/client/reinforced/service.go b/example/client/reinforced/service.go new file mode 100644 index 0000000..9e93f17 --- /dev/null +++ b/example/client/reinforced/service.go @@ -0,0 +1,50 @@ +// Code generated by reinforcer, DO NOT EDIT. + +package reinforced + +import "context" + +type targetService interface { + GetData() ([]byte, error) +} +type Service struct { + *base + delegate targetService +} + +func NewService(delegate targetService, runnerFactory runnerFactory, options ...Option) *Service { + if delegate == nil { + panic("provided nil delegate") + } + if runnerFactory == nil { + panic("provided nil runner factory") + } + c := &Service{ + base: &base{ + errorPredicate: RetryAllErrors, + runnerFactory: runnerFactory, + }, + delegate: delegate, + } + for _, o := range options { + o(c.base) + } + return c +} +func (s *Service) GetData() ([]byte, error) { + var nonRetryableErr error + var r0 []byte + err := s.run(context.Background(), ServiceMethods.GetData, func(_ context.Context) error { + var err error + r0, err = s.delegate.GetData() + if s.errorPredicate(ServiceMethods.GetData, err) { + return err + } + nonRetryableErr = err + return nil + }) + if nonRetryableErr != nil { + return r0, nonRetryableErr + } + return r0, err +} diff --git a/internal/generator/executor/executor.go b/internal/generator/executor/executor.go index 126a68d..e0e0017 100644 --- a/internal/generator/executor/executor.go +++ b/internal/generator/executor/executor.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/csueiras/reinforcer/internal/generator" "github.com/csueiras/reinforcer/internal/loader" - "go/types" ) // ErrNoTargetableTypesFound indicates that no types that could be targeted for code generation were discovered @@ -44,7 +43,8 @@ func New(l Loader) *Executor { // Execute orchestrates code generation sourced from multiple files/targets func (e *Executor) Execute(settings *Parameters) (*generator.Generated, error) { - results := make(map[string]*types.Interface) + discoveredTypes := make(map[string]struct{}) + var cfg []*generator.FileConfig var err error for _, source := range settings.Sources { @@ -60,11 +60,11 @@ func (e *Executor) Execute(settings *Parameters) (*generator.Generated, error) { // Check types aren't repeated before adding them to the generator's config for typName, res := range match { - if _, ok := results[typName]; ok { + if _, ok := discoveredTypes[typName]; ok { return nil, fmt.Errorf("multiple types with same name discovered with name %s", typName) } - results[typName] = res.InterfaceType - cfg = append(cfg, generator.NewFileConfig(typName, typName, res.InterfaceType)) + discoveredTypes[typName] = struct{}{} + cfg = append(cfg, generator.NewFileConfig(typName, typName, res.Methods)) } } diff --git a/internal/generator/executor/executor_test.go b/internal/generator/executor/executor_test.go index b749b2c..8832d42 100644 --- a/internal/generator/executor/executor_test.go +++ b/internal/generator/executor/executor_test.go @@ -3,9 +3,9 @@ package executor_test import ( "github.com/csueiras/reinforcer/internal/generator/executor" "github.com/csueiras/reinforcer/internal/generator/executor/mocks" + "github.com/csueiras/reinforcer/internal/generator/method" "github.com/csueiras/reinforcer/internal/loader" "github.com/stretchr/testify/require" - "go/token" "go/types" "testing" ) @@ -16,8 +16,8 @@ func TestExecutor_Execute(t *testing.T) { l.On("LoadMatched", "./testpkg.go", []string{"MyService"}, loader.FileLoadMode).Return( map[string]*loader.Result{ "LockService": { - Name: "LockService", - InterfaceType: createTestInterfaceType(), + Name: "LockService", + Methods: createTestServiceMethods(), }, }, nil, ) @@ -52,11 +52,10 @@ func TestExecutor_Execute(t *testing.T) { }) } -func createTestInterfaceType() *types.Interface { +func createTestServiceMethods() []*method.Method { nullary := types.NewSignature(nil, nil, nil, false) // func() - methods := []*types.Func{ - types.NewFunc(token.NoPos, nil, "Lock", nullary), - types.NewFunc(token.NoPos, nil, "Unlock", nullary), + return []*method.Method{ + method.MustParseMethod("Lock", nullary), + method.MustParseMethod("Unlock", nullary), } - return types.NewInterfaceType(methods, nil).Complete() } diff --git a/internal/generator/generator.go b/internal/generator/generator.go index fbc5eda..3f18c19 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -9,7 +9,6 @@ import ( "github.com/csueiras/reinforcer/internal/generator/retryable" "github.com/dave/jennifer/jen" "github.com/rs/zerolog/log" - "go/types" "strings" ) @@ -21,16 +20,16 @@ type FileConfig struct { srcTypeName string // outTypeName is the desired output type name outTypeName string - // interfaceType holds the type information for SrcTypeName - interfaceType *types.Interface + // methods that should be in the generated type + methods []*method.Method } // NewFileConfig creates a new instance of the FileConfig which holds code generation configuration -func NewFileConfig(srcTypeName string, outTypeName string, interfaceType *types.Interface) *FileConfig { +func NewFileConfig(srcTypeName, outTypeName string, methods []*method.Method) *FileConfig { return &FileConfig{ - srcTypeName: strings.Title(srcTypeName), - outTypeName: strings.Title(outTypeName), - interfaceType: interfaceType, + srcTypeName: strings.Title(srcTypeName), + outTypeName: strings.Title(outTypeName), + methods: methods, } } @@ -97,10 +96,7 @@ func Generate(cfg Config) (*Generated, error) { var fileMethods []*fileMeta for _, fileConfig := range cfg.Files { - methods, err := parseMethods(fileConfig.outTypeName, fileConfig.interfaceType) - if err != nil { - return nil, err - } + methods := fileConfig.methods s, err := generateFile(cfg.OutPkg, cfg.IgnoreNoReturnMethods, fileConfig, methods) if err != nil { return nil, err @@ -268,19 +264,6 @@ func generateConstants(outPkg string, meta []*fileMeta) (string, error) { return renderToString(f) } -func parseMethods(typeName string, interfaceType *types.Interface) ([]*method.Method, error) { - var methods []*method.Method - for m := 0; m < interfaceType.NumMethods(); m++ { - meth := interfaceType.Method(m) - mm, err := method.ParseMethod(typeName, meth.Name(), meth.Type().(*types.Signature)) - if err != nil { - return nil, err - } - methods = append(methods, mm) - } - return methods, nil -} - func renderToString(f *jen.File) (string, error) { b := &bytes.Buffer{} if err := f.Render(b); err != nil { diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index 48a5073..fd0ce58 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -921,7 +921,7 @@ func loadInterface(t *testing.T, filesCode map[string]input) []*generator.FileCo require.NoError(t, err) loadedTypes = append(loadedTypes, generator.NewFileConfig(in.interfaceName, fmt.Sprintf("Generated%s", strings.Title(in.interfaceName)), - svc.InterfaceType, + svc.Methods, )) } return loadedTypes diff --git a/internal/generator/method/method.go b/internal/generator/method/method.go index e94528a..9b12198 100644 --- a/internal/generator/method/method.go +++ b/internal/generator/method/method.go @@ -2,7 +2,7 @@ package method import ( "fmt" - "github.com/csueiras/reinforcer/internal/loader" + rtypes "github.com/csueiras/reinforcer/internal/types" "github.com/dave/jennifer/jen" "go/types" ) @@ -15,36 +15,8 @@ type named interface { Name() string } -// ErrType is the types.Type for the error interface -var ErrType types.Type - -// ContextType is the types.Type for the context.Context interface -var ContextType *types.Interface - -func init() { - errType := types.NewInterfaceType([]*types.Func{ - types.NewFunc(0, nil, "Error", - types.NewSignature( - nil, - types.NewTuple(), - types.NewTuple(types.NewParam(0, nil, "", types.Typ[types.String])), - false, - ), - ), - }, nil) - errType.Complete() - ErrType = types.NewNamed(types.NewTypeName(0, nil, "error", nil), errType, nil) - - iface, err := loader.DefaultLoader().LoadOne("context", "Context", loader.PackageLoadMode) - if err != nil { - panic(err) - } - ContextType = iface.InterfaceType -} - // Method holds all of the data for code generation on a specific method signature type Method struct { - ParentTypeName string Name string HasContext bool ReturnsError bool @@ -57,8 +29,8 @@ type Method struct { } // ConstantRef is the reference to the constant for this method's name -func (m *Method) ConstantRef() jen.Code { - constantsStructName := fmt.Sprintf("%sMethods", m.ParentTypeName) +func (m *Method) ConstantRef(parentTypeName string) jen.Code { + constantsStructName := fmt.Sprintf("%sMethods", parentTypeName) return jen.Id(constantsStructName).Dot(m.Name) } @@ -89,10 +61,18 @@ func (m *Method) Parameters() []jen.Code { return params } +// MustParseMethod parses the given types.Signature and generates a Method, if there's an error this method will panic +func MustParseMethod(name string, signature *types.Signature) *Method { + m, err := ParseMethod(name, signature) + if err != nil { + panic(err) + } + return m +} + // ParseMethod parses the given types.Signature and generates a Method -func ParseMethod(parentTypeName, name string, signature *types.Signature) (*Method, error) { +func ParseMethod(name string, signature *types.Signature) (*Method, error) { m := &Method{ - ParentTypeName: parentTypeName, Name: name, ReturnErrorIndex: nil, ContextParameter: nil, @@ -103,7 +83,7 @@ func ParseMethod(parentTypeName, name string, signature *types.Signature) (*Meth numParams := signature.Params().Len() for i, lastIndex := 0, numParams-1; i < numParams; i++ { param := signature.Params().At(i) - if isContextType(param.Type()) { + if rtypes.IsContextType(param.Type()) { m.HasContext = true m.ContextParameter = new(int) *m.ContextParameter = i @@ -126,7 +106,7 @@ func ParseMethod(parentTypeName, name string, signature *types.Signature) (*Meth if err != nil { panic(err) } - if isErrorType(res.Type()) { + if rtypes.IsErrorType(res.Type()) { if m.ReturnErrorIndex != nil { return nil, fmt.Errorf("multiple errors returned by method signature") } @@ -139,25 +119,6 @@ func ParseMethod(parentTypeName, name string, signature *types.Signature) (*Meth return m, nil } -// isErrorType determines if the given type implements the Error interface -func isErrorType(t types.Type) bool { - if t == nil { - return false - } - return types.Implements(t, ErrType.Underlying().(*types.Interface)) -} - -// isContextType determines if the given type is context.Context -func isContextType(t types.Type) bool { - if t == nil { - return false - } - if t.String() == "context.Context" { - return true - } - return types.Implements(t, ContextType) -} - // variadicToType generates the representation for a variadic type "...MyType" func variadicToType(t types.Type) (jen.Code, error) { sliceType, ok := t.(*types.Slice) diff --git a/internal/generator/method/method_test.go b/internal/generator/method/method_test.go index 99db205..286bf80 100644 --- a/internal/generator/method/method_test.go +++ b/internal/generator/method/method_test.go @@ -3,6 +3,7 @@ package method_test import ( "fmt" "github.com/csueiras/reinforcer/internal/generator/method" + rtypes "github.com/csueiras/reinforcer/internal/types" "github.com/dave/jennifer/jen" "github.com/stretchr/testify/require" "go/token" @@ -11,7 +12,7 @@ import ( ) func TestNewMethod(t *testing.T) { - ctxVar := types.NewVar(token.NoPos, nil, "ctx", method.ContextType) + ctxVar := types.NewVar(token.NoPos, nil, "ctx", rtypes.ContextType) zero := new(int) *zero = 0 @@ -105,7 +106,7 @@ func TestNewMethod(t *testing.T) { signature: types.NewSignature(nil, types.NewTuple( ctxVar, types.NewVar(token.NoPos, nil, "myArg", types.Typ[types.String]), - ), types.NewTuple(types.NewVar(token.NoPos, nil, "", method.ErrType)), false), + ), types.NewTuple(types.NewVar(token.NoPos, nil, "", rtypes.ErrType)), false), }, want: &method.Method{ Name: "Fn", @@ -125,7 +126,7 @@ func TestNewMethod(t *testing.T) { types.NewVar(token.NoPos, nil, "myArg", types.Typ[types.String]), ), types.NewTuple( types.NewVar(token.NoPos, nil, "", types.Typ[types.String]), - types.NewVar(token.NoPos, nil, "", method.ErrType), + types.NewVar(token.NoPos, nil, "", rtypes.ErrType), ), false), }, want: &method.Method{ @@ -175,7 +176,7 @@ func TestNewMethod(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := method.ParseMethod("ParentType", tt.args.name, tt.args.signature) + got, err := method.ParseMethod(tt.args.name, tt.args.signature) require.NoError(t, err) require.Equal(t, tt.want.Name, got.Name) require.Equal(t, tt.want.HasContext, got.HasContext) @@ -189,7 +190,7 @@ func TestNewMethod(t *testing.T) { require.ElementsMatch(t, tt.want.ParameterNames, got.ParameterNames) require.ElementsMatch(t, tt.want.ParametersNameAndType, got.ParametersNameAndType) require.ElementsMatch(t, tt.want.ReturnTypes, got.ReturnTypes) - require.Equal(t, fmt.Sprintf("ParentTypeMethods.%s", tt.want.Name), (got.ConstantRef().(*jen.Statement)).GoString()) + require.Equal(t, fmt.Sprintf("ParentTypeMethods.%s", tt.want.Name), (got.ConstantRef("ParentType").(*jen.Statement)).GoString()) }) } } diff --git a/internal/generator/noret/noret.go b/internal/generator/noret/noret.go index 4b6166b..aa82898 100644 --- a/internal/generator/noret/noret.go +++ b/internal/generator/noret/noret.go @@ -36,7 +36,7 @@ func (p *NoReturn) Statement() (*jen.Statement, error) { ) return jen.Func().Params(jen.Id(p.receiverName).Op("*").Id(p.structName)).Id(p.method.Name).Call(methodArgParams...).Block( - jen.Id("err").Op(":=").Id(p.receiverName).Dot("run").Call(ctxParam, p.method.ConstantRef(), call), + jen.Id("err").Op(":=").Id(p.receiverName).Dot("run").Call(ctxParam, p.method.ConstantRef(p.structName), call), jen.If(jen.Id("err").Op("!=").Nil()).Block( jen.Panic(jen.Id("err")), ), diff --git a/internal/generator/noret/noret_test.go b/internal/generator/noret/noret_test.go index 5e9cb07..3126f7d 100644 --- a/internal/generator/noret/noret_test.go +++ b/internal/generator/noret/noret_test.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/csueiras/reinforcer/internal/generator/method" "github.com/csueiras/reinforcer/internal/generator/noret" + rtypes "github.com/csueiras/reinforcer/internal/types" "github.com/stretchr/testify/require" "go/token" "go/types" @@ -11,7 +12,7 @@ import ( ) func TestNoReturn_Statement(t *testing.T) { - ctxVar := types.NewVar(token.NoPos, nil, "ctx", method.ContextType) + ctxVar := types.NewVar(token.NoPos, nil, "ctx", rtypes.ContextType) tests := []struct { name string @@ -24,8 +25,8 @@ func TestNoReturn_Statement(t *testing.T) { name: "MyFunction()", methodName: "MyFunction", signature: types.NewSignature(nil, types.NewTuple(), types.NewTuple(), false), - want: `func (r *resilient) MyFunction() { - err := r.run(context.Background(), ParentMethods.MyFunction, func(_ context.Context) error { + want: `func (r *Resilient) MyFunction() { + err := r.run(context.Background(), ResilientMethods.MyFunction, func(_ context.Context) error { r.delegate.MyFunction() return nil }) @@ -42,8 +43,8 @@ func TestNoReturn_Statement(t *testing.T) { ctxVar, types.NewVar(token.NoPos, nil, "myArg", types.Typ[types.String]), ), types.NewTuple(types.NewVar(token.NoPos, nil, "", types.Typ[types.String])), false), - want: `func (r *resilient) MyFunction(ctx context.Context, arg1 string) { - err := r.run(ctx, ParentMethods.MyFunction, func(ctx context.Context) error { + want: `func (r *Resilient) MyFunction(ctx context.Context, arg1 string) { + err := r.run(ctx, ResilientMethods.MyFunction, func(ctx context.Context) error { r.delegate.MyFunction(ctx, arg1) return nil }) @@ -57,9 +58,9 @@ func TestNoReturn_Statement(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m, err := method.ParseMethod("Parent", tt.methodName, tt.signature) + m, err := method.ParseMethod(tt.methodName, tt.signature) require.NoError(t, err) - ret := noret.NewNoReturn(m, "resilient", "r") + ret := noret.NewNoReturn(m, "Resilient", "r") buf := &bytes.Buffer{} s, err := ret.Statement() if tt.wantErr { @@ -71,7 +72,6 @@ func TestNoReturn_Statement(t *testing.T) { require.NoError(t, renderErr) got := buf.String() - //fmt.Print(got) require.Equal(t, tt.want, got) } }) diff --git a/internal/generator/passthrough/passthrough_test.go b/internal/generator/passthrough/passthrough_test.go index 8296e47..437e726 100644 --- a/internal/generator/passthrough/passthrough_test.go +++ b/internal/generator/passthrough/passthrough_test.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/csueiras/reinforcer/internal/generator/method" "github.com/csueiras/reinforcer/internal/generator/passthrough" + rtypes "github.com/csueiras/reinforcer/internal/types" "github.com/stretchr/testify/require" "go/token" "go/types" @@ -11,7 +12,7 @@ import ( ) func TestPassThrough_Statement(t *testing.T) { - ctxVar := types.NewVar(token.NoPos, nil, "ctx", method.ContextType) + ctxVar := types.NewVar(token.NoPos, nil, "ctx", rtypes.ContextType) tests := []struct { name string @@ -57,7 +58,7 @@ func TestPassThrough_Statement(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m, err := method.ParseMethod("ParentType", tt.methodName, tt.signature) + m, err := method.ParseMethod(tt.methodName, tt.signature) require.NoError(t, err) ret := passthrough.NewPassThrough(m, "resilient", "r") buf := &bytes.Buffer{} diff --git a/internal/generator/retryable/retryable.go b/internal/generator/retryable/retryable.go index 4188222..cdf3b9c 100644 --- a/internal/generator/retryable/retryable.go +++ b/internal/generator/retryable/retryable.go @@ -13,7 +13,8 @@ const ( // Retryable is a code generator for a method that can be retried on error type Retryable struct { - method *method.Method + method *method.Method + //originalTypeName string structName string receiverName string } @@ -80,7 +81,7 @@ func (r *Retryable) methodCall() ([]jen.Code, error) { // if r.errorPredicate(methodName, err) { // return err // } - jen.If(jen.Id(r.receiverName).Dot("errorPredicate").Call(r.method.ConstantRef(), jen.Id(errVarName))).Block( + jen.If(jen.Id(r.receiverName).Dot("errorPredicate").Call(r.method.ConstantRef(r.structName), jen.Id(errVarName))).Block( jen.Return(jen.Id("err")), ), // nonRetryableErr = err @@ -89,7 +90,7 @@ func (r *Retryable) methodCall() ([]jen.Code, error) { jen.Return(jen.Nil()), ) - statements = append(statements, jen.Id("err").Op(":=").Id(r.receiverName).Dot("run").Call(ctxParam, r.method.ConstantRef(), call)) + statements = append(statements, jen.Id("err").Op(":=").Id(r.receiverName).Dot("run").Call(ctxParam, r.method.ConstantRef(r.structName), call)) nonRetryErrReturns := make([]jen.Code, len(returnVars)) copy(nonRetryErrReturns, returnVars) diff --git a/internal/generator/retryable/retryable_test.go b/internal/generator/retryable/retryable_test.go index 9165f61..e3675e2 100644 --- a/internal/generator/retryable/retryable_test.go +++ b/internal/generator/retryable/retryable_test.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/csueiras/reinforcer/internal/generator/method" "github.com/csueiras/reinforcer/internal/generator/retryable" + rtypes "github.com/csueiras/reinforcer/internal/types" "github.com/stretchr/testify/require" "go/token" "go/types" @@ -11,8 +12,8 @@ import ( ) func TestRetryable_Statement(t *testing.T) { - errVar := types.NewVar(token.NoPos, nil, "", method.ErrType) - ctxVar := types.NewVar(token.NoPos, nil, "ctx", method.ContextType) + errVar := types.NewVar(token.NoPos, nil, "", rtypes.ErrType) + ctxVar := types.NewVar(token.NoPos, nil, "ctx", rtypes.ContextType) tests := []struct { name string @@ -25,12 +26,12 @@ func TestRetryable_Statement(t *testing.T) { name: "Function returns error", methodName: "MyFunction", signature: types.NewSignature(nil, types.NewTuple(), types.NewTuple(errVar), false), - want: `func (r *resilient) MyFunction() error { + want: `func (r *Resilient) MyFunction() error { var nonRetryableErr error - err := r.run(context.Background(), ParentMethods.MyFunction, func(_ context.Context) error { + err := r.run(context.Background(), ResilientMethods.MyFunction, func(_ context.Context) error { var err error err = r.delegate.MyFunction() - if r.errorPredicate(ParentMethods.MyFunction, err) { + if r.errorPredicate(ResilientMethods.MyFunction, err) { return err } nonRetryableErr = err @@ -47,13 +48,13 @@ func TestRetryable_Statement(t *testing.T) { name: "Function returns string and error", methodName: "MyFunction", signature: types.NewSignature(nil, types.NewTuple(), types.NewTuple(types.NewVar(token.NoPos, nil, "", types.Typ[types.String]), errVar), false), - want: `func (r *resilient) MyFunction() (string, error) { + want: `func (r *Resilient) MyFunction() (string, error) { var nonRetryableErr error var r0 string - err := r.run(context.Background(), ParentMethods.MyFunction, func(_ context.Context) error { + err := r.run(context.Background(), ResilientMethods.MyFunction, func(_ context.Context) error { var err error r0, err = r.delegate.MyFunction() - if r.errorPredicate(ParentMethods.MyFunction, err) { + if r.errorPredicate(ResilientMethods.MyFunction, err) { return err } nonRetryableErr = err @@ -73,13 +74,13 @@ func TestRetryable_Statement(t *testing.T) { ctxVar, types.NewVar(token.NoPos, nil, "myArg", types.Typ[types.String]), ), types.NewTuple(types.NewVar(token.NoPos, nil, "", types.Typ[types.String]), errVar), false), - want: `func (r *resilient) MyFunction(ctx context.Context, arg1 string) (string, error) { + want: `func (r *Resilient) MyFunction(ctx context.Context, arg1 string) (string, error) { var nonRetryableErr error var r0 string - err := r.run(ctx, ParentMethods.MyFunction, func(ctx context.Context) error { + err := r.run(ctx, ResilientMethods.MyFunction, func(ctx context.Context) error { var err error r0, err = r.delegate.MyFunction(ctx, arg1) - if r.errorPredicate(ParentMethods.MyFunction, err) { + if r.errorPredicate(ResilientMethods.MyFunction, err) { return err } nonRetryableErr = err @@ -96,9 +97,9 @@ func TestRetryable_Statement(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m, err := method.ParseMethod("Parent", tt.methodName, tt.signature) + m, err := method.ParseMethod(tt.methodName, tt.signature) require.NoError(t, err) - ret := retryable.NewRetryable(m, "resilient", "r") + ret := retryable.NewRetryable(m, "Resilient", "r") buf := &bytes.Buffer{} s, err := ret.Statement() if tt.wantErr { @@ -117,9 +118,9 @@ func TestRetryable_Statement(t *testing.T) { t.Run("Function does not return error", func(t *testing.T) { require.Panics(t, func() { - m, err := method.ParseMethod("Parent", "Fn", types.NewSignature(nil, types.NewTuple(), types.NewTuple(), false)) + m, err := method.ParseMethod("Fn", types.NewSignature(nil, types.NewTuple(), types.NewTuple(), false)) require.NoError(t, err) - retryable.NewRetryable(m, "resilient", "r") + retryable.NewRetryable(m, "Resilient", "r") }) }) } diff --git a/internal/loader/loader.go b/internal/loader/loader.go index f39288f..9e76bc3 100644 --- a/internal/loader/loader.go +++ b/internal/loader/loader.go @@ -2,12 +2,15 @@ package loader import ( "fmt" + "github.com/csueiras/reinforcer/internal/generator/method" "github.com/rs/zerolog/log" + "go/ast" "go/types" "golang.org/x/tools/go/packages" "path/filepath" "regexp" "strings" + "unicode" ) // LoadMode determines how a path should be loaded @@ -38,8 +41,8 @@ func (l *LoadingError) Error() string { // Result holds the results of loading a particular type type Result struct { - Name string - InterfaceType *types.Interface + Name string + Methods []*method.Method } // Loader is a utility service for extracting type information from a go package @@ -114,22 +117,42 @@ func (l *Loader) loadExpr(path string, expr *regexp.Regexp, mode LoadMode) (*pac pkg := pkgs[0] typesFound := pkg.Types.Scope().Names() results := make(map[string]*Result) + + var matchingTypes []string for _, typeFound := range typesFound { if expr.MatchString(typeFound) { - obj := pkg.Types.Scope().Lookup(typeFound) - if obj == nil { - return nil, nil, fmt.Errorf("%s not found in declared types of %s", typeFound, pkg) + matchingTypes = append(matchingTypes, typeFound) + } + } + + log.Info().Msgf("Matching types to target expressions: %s", strings.Join(matchingTypes, ", ")) + + for _, typeFound := range matchingTypes { + obj := pkg.Types.Scope().Lookup(typeFound) + if obj == nil { + return nil, nil, fmt.Errorf("%s not found in declared types of %s", typeFound, pkg) + } + + switch typ := obj.Type().Underlying().(type) { + case *types.Interface: + log.Info().Msgf("Discovered interface type %s", typeFound) + result, err := loadFromInterface(typeFound, typ) + if err != nil { + return nil, nil, err } - interfaceType, ok := obj.Type().Underlying().(*types.Interface) - if !ok { - log.Debug().Msgf("Ignoring matching type %s because it is not an interface type", typeFound) - continue + results[typeFound] = result + case *types.Struct: + log.Info().Msgf("Discovered struct type %s", typeFound) + result, err := loadFromStruct(pkg.Syntax[0], typeFound, pkg.TypesInfo) + if err != nil { + return nil, nil, err } - log.Info().Msgf("Discovered type %s", typeFound) - results[typeFound] = &Result{ - Name: typeFound, - InterfaceType: interfaceType, + + if len(result.Methods) > 0 { + results[typeFound] = result } + default: + log.Debug().Msgf("Ignoring matching type %s because it is not an interface nor struct type", typeFound) } } return pkg, results, nil @@ -137,7 +160,7 @@ func (l *Loader) loadExpr(path string, expr *regexp.Regexp, mode LoadMode) (*pac func (l *Loader) load(path string, mode LoadMode) ([]*packages.Package, error) { cfg := &packages.Config{ - Mode: packages.NeedTypes | packages.NeedImports | packages.NeedSyntax, + Mode: packages.NeedTypes | packages.NeedImports | packages.NeedSyntax | packages.NeedTypesInfo, } var pkgs []*packages.Package @@ -162,6 +185,70 @@ func (l *Loader) load(path string, mode LoadMode) ([]*packages.Package, error) { return pkgs, nil } +func loadFromInterface(name string, interfaceType *types.Interface) (*Result, error) { + result := &Result{ + Name: name, + } + for m := 0; m < interfaceType.NumMethods(); m++ { + meth := interfaceType.Method(m) + mm, err := method.ParseMethod(meth.Name(), meth.Type().(*types.Signature)) + if err != nil { + return nil, err + } + result.Methods = append(result.Methods, mm) + } + return result, nil +} + +func loadFromStruct(f *ast.File, name string, info *types.Info) (*Result, error) { + result := &Result{ + Name: name, + } + var firstError error + ast.Inspect(f, func(node ast.Node) bool { + fn, ok := node.(*ast.FuncDecl) + if !ok { + return true + } + if fn.Recv == nil { + return true + } + for _, l := range fn.Recv.List { + var ident *ast.Ident + switch t := l.Type.(type) { + case *ast.StarExpr: + ident = t.X.(*ast.Ident) + case *ast.Ident: + ident = t + } + + if ident == nil || ident.Name != name { + continue + } + + // Ignore unexported methods + if !unicode.IsUpper(rune(fn.Name.Name[0])) { + log.Debug().Msgf("Ignoring function %s as it is unexported", fn.Name.Name) + continue + } + + meth, err := method.ParseMethod(fn.Name.Name, info.Defs[fn.Name].Type().(*types.Signature)) + if err != nil { + if firstError == nil { + firstError = err + } + return false + } + result.Methods = append(result.Methods, meth) + } + return true + }) + if firstError != nil { + return nil, firstError + } + return result, nil +} + func extractPackageErrors(pkgs []*packages.Package) error { var errors []error packages.Visit(pkgs, nil, func(pkg *packages.Package) { @@ -178,13 +265,20 @@ func extractPackageErrors(pkgs []*packages.Package) error { } func exprToFilter(expressions []string) (*regexp.Regexp, error) { - expression := strings.Join(expressions, "|") - if strings.ContainsAny(expression, regexChars) { - filter, err := regexp.Compile(expression) - if err != nil { - return nil, fmt.Errorf("failed to compile expression %q; error=%w", expression, err) + var filter []string + for _, expr := range expressions { + if strings.ContainsAny(expr, regexChars) { + // RegEx expression + filter = append(filter, expr) + } else { + // Exact match + filter = append(filter, fmt.Sprintf("\\b%s\\b", expr)) } - return filter, nil } - return regexp.MustCompile(fmt.Sprintf("\\b%s\\b", expression)), nil + expression := strings.Join(filter, "|") + reFilter, err := regexp.Compile(expression) + if err != nil { + return nil, fmt.Errorf("failed to compile expression %q; error=%w", expression, err) + } + return reFilter, nil } diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index b9d7dad..96cb826 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -10,10 +10,11 @@ import ( ) func TestLoad(t *testing.T) { - exported := packagestest.Export(t, packagestest.GOPATH, []packagestest.Module{{ - Name: "github.com/csueiras", - Files: map[string]interface{}{ - "fake/fake.go": `package fake + t.Run("Load Interface", func(t *testing.T) { + exported := packagestest.Export(t, packagestest.GOPATH, []packagestest.Module{{ + Name: "github.com/csueiras", + Files: map[string]interface{}{ + "fake/fake.go": `package fake import "context" @@ -21,18 +22,52 @@ type Service interface { GetUserID(ctx context.Context, userID string) (string, error) } `, - }}}) - defer exported.Cleanup() + }}}) + defer exported.Cleanup() - l := loader.NewLoader(func(cfg *packages.Config, patterns ...string) ([]*packages.Package, error) { - exported.Config.Mode = cfg.Mode - return packages.Load(exported.Config, patterns...) + l := loader.NewLoader(func(cfg *packages.Config, patterns ...string) ([]*packages.Package, error) { + exported.Config.Mode = cfg.Mode + return packages.Load(exported.Config, patterns...) + }) + + svc, err := l.LoadOne("github.com/csueiras/fake", "Service", loader.PackageLoadMode) + require.NoError(t, err) + require.NotNil(t, svc) + require.Equal(t, "Service", svc.Name) + require.Equal(t, 1, len(svc.Methods)) + require.Equal(t, "GetUserID", svc.Methods[0].Name) }) - svc, err := l.LoadOne("github.com/csueiras/fake", "Service", loader.PackageLoadMode) - require.NoError(t, err) - require.NotNil(t, svc) - require.Equal(t, "interface{GetUserID(ctx context.Context, userID string) (string, error)}", svc.InterfaceType.String()) + t.Run("Load Struct", func(t *testing.T) { + exported := packagestest.Export(t, packagestest.GOPATH, []packagestest.Module{{ + Name: "github.com/csueiras", + Files: map[string]interface{}{ + "fake/fake.go": `package fake + +import "context" + +type service struct { +} + +func (s *service) GetUserID(ctx context.Context, userID string) (string, error) { + return "My User", nil +} +`, + }}}) + defer exported.Cleanup() + + l := loader.NewLoader(func(cfg *packages.Config, patterns ...string) ([]*packages.Package, error) { + exported.Config.Mode = cfg.Mode + return packages.Load(exported.Config, patterns...) + }) + + svc, err := l.LoadOne("github.com/csueiras/fake", "service", loader.PackageLoadMode) + require.NoError(t, err) + require.NotNil(t, svc) + require.Equal(t, "service", svc.Name) + require.Equal(t, 1, len(svc.Methods)) + require.Equal(t, "GetUserID", svc.Methods[0].Name) + }) } func TestLoadMatched(t *testing.T) { @@ -55,7 +90,7 @@ type unexportedService interface { HelloWorld() } -type NotAnInterface struct { +type StructWithNoMethods struct { SomeField string } `, @@ -72,12 +107,18 @@ type NotAnInterface struct { require.NoError(t, err) require.NotNil(t, results) require.Equal(t, 3, len(results)) - require.NotNil(t, results["UserService"]) - require.Equal(t, "interface{GetUserID(ctx context.Context, userID string) (string, error)}", results["UserService"].InterfaceType.String()) - require.NotNil(t, results["HelloWorldService"]) - require.Equal(t, "interface{Hello(ctx context.Context, name string) error}", results["HelloWorldService"].InterfaceType.String()) - require.NotNil(t, results["unexportedService"]) - require.Equal(t, "interface{HelloWorld()}", results["unexportedService"].InterfaceType.String()) + + require.Equal(t, "UserService", results["UserService"].Name) + require.Equal(t, 1, len(results["UserService"].Methods)) + require.Equal(t, "GetUserID", results["UserService"].Methods[0].Name) + + require.Equal(t, "HelloWorldService", results["HelloWorldService"].Name) + require.Equal(t, 1, len(results["HelloWorldService"].Methods)) + require.Equal(t, "Hello", results["HelloWorldService"].Methods[0].Name) + + require.Equal(t, "unexportedService", results["unexportedService"].Name) + require.Equal(t, 1, len(results["unexportedService"].Methods)) + require.Equal(t, "HelloWorld", results["unexportedService"].Methods[0].Name) }) t.Run("Multiple RegEx Expressions", func(t *testing.T) { @@ -90,10 +131,14 @@ type NotAnInterface struct { require.NoError(t, err) require.NotNil(t, results) require.Equal(t, 2, len(results)) - require.NotNil(t, results["UserService"]) - require.Equal(t, "interface{GetUserID(ctx context.Context, userID string) (string, error)}", results["UserService"].InterfaceType.String()) - require.NotNil(t, results["HelloWorldService"]) - require.Equal(t, "interface{Hello(ctx context.Context, name string) error}", results["HelloWorldService"].InterfaceType.String()) + + require.Equal(t, "UserService", results["UserService"].Name) + require.Equal(t, 1, len(results["UserService"].Methods)) + require.Equal(t, "GetUserID", results["UserService"].Methods[0].Name) + + require.Equal(t, "HelloWorldService", results["HelloWorldService"].Name) + require.Equal(t, 1, len(results["HelloWorldService"].Methods)) + require.Equal(t, "Hello", results["HelloWorldService"].Methods[0].Name) }) t.Run("Exact Match", func(t *testing.T) { @@ -106,8 +151,9 @@ type NotAnInterface struct { require.NoError(t, err) require.NotNil(t, results) require.Equal(t, 1, len(results)) - require.NotNil(t, results["HelloWorldService"]) - require.Equal(t, "interface{Hello(ctx context.Context, name string) error}", results["HelloWorldService"].InterfaceType.String()) + require.Equal(t, "HelloWorldService", results["HelloWorldService"].Name) + require.Equal(t, 1, len(results["HelloWorldService"].Methods)) + require.Equal(t, "Hello", results["HelloWorldService"].Methods[0].Name) }) t.Run("Exact Match: No Match", func(t *testing.T) { @@ -128,13 +174,17 @@ type NotAnInterface struct { return packages.Load(exported.Config, patterns...) }) - results, err := l.LoadMatched("github.com/csueiras/fake", []string{"UserService", "HelloWorldService", "NotAnInterface"}, loader.PackageLoadMode) + results, err := l.LoadMatched("github.com/csueiras/fake", []string{"UserService", "HelloWorldService", "StructWithNoMethods"}, loader.PackageLoadMode) require.NoError(t, err) require.NotNil(t, results) require.Equal(t, 2, len(results)) - require.NotNil(t, results["UserService"]) - require.Equal(t, "interface{GetUserID(ctx context.Context, userID string) (string, error)}", results["UserService"].InterfaceType.String()) - require.NotNil(t, results["HelloWorldService"]) - require.Equal(t, "interface{Hello(ctx context.Context, name string) error}", results["HelloWorldService"].InterfaceType.String()) + + require.Equal(t, "UserService", results["UserService"].Name) + require.Equal(t, 1, len(results["UserService"].Methods)) + require.Equal(t, "GetUserID", results["UserService"].Methods[0].Name) + + require.Equal(t, "HelloWorldService", results["HelloWorldService"].Name) + require.Equal(t, 1, len(results["HelloWorldService"].Methods)) + require.Equal(t, "Hello", results["HelloWorldService"].Methods[0].Name) }) } diff --git a/internal/testpkg/teststruct.go b/internal/testpkg/teststruct.go new file mode 100644 index 0000000..3bc3258 --- /dev/null +++ b/internal/testpkg/teststruct.go @@ -0,0 +1,20 @@ +package testpkg + +import "context" + +type service struct { +} + +func (s *service) unexportedOperation(arg string) (string, error) { + return arg, nil +} + +func (s *service) GetUserByID(_ context.Context, _ string) (string, error) { + return "Christian", nil +} + +type anotherService struct{} + +func (a anotherService) DoOperation() { + +} diff --git a/internal/types/types.go b/internal/types/types.go new file mode 100644 index 0000000..912f8c4 --- /dev/null +++ b/internal/types/types.go @@ -0,0 +1,60 @@ +package types + +import ( + "go/types" + "golang.org/x/tools/go/packages" +) + +// ErrType is the types.Type for the error interface +var ErrType types.Type + +// ContextType is the types.Type for the context.Context interface +var ContextType *types.Interface + +func init() { + errType := types.NewInterfaceType([]*types.Func{ + types.NewFunc(0, nil, "Error", + types.NewSignature( + nil, + types.NewTuple(), + types.NewTuple(types.NewParam(0, nil, "", types.Typ[types.String])), + false, + ), + ), + }, nil) + errType.Complete() + ErrType = types.NewNamed(types.NewTypeName(0, nil, "error", nil), errType, nil) + + // Load the type definition for the Context type + ctxPkg, err := packages.Load(&packages.Config{ + Mode: packages.NeedTypes | packages.NeedImports | packages.NeedSyntax | packages.NeedTypesInfo, + }, "context") + if err != nil { + panic(err) + } + ContextType = ctxPkg[0].Types. + Scope(). + Lookup("Context"). + Type().(*types.Named). + Underlying(). + (*types.Interface) +} + +// IsErrorType determines if the given type implements the Error interface +func IsErrorType(t types.Type) bool { + if t == nil { + return false + } + return types.Implements(t, ErrType.Underlying().(*types.Interface)) +} + +// IsContextType determines if the given type is context.Context +func IsContextType(t types.Type) bool { + if t == nil { + return false + } + if t.String() == "context.Context" { + return true + } + return types.Implements(t, ContextType) +} From e7fe85c1f65c8ea77a59f6f613d65aa39b730743 Mon Sep 17 00:00:00 2001 From: Christian Sueiras Date: Sun, 7 Mar 2021 18:47:12 -0500 Subject: [PATCH 2/5] Remove commented out line --- internal/generator/retryable/retryable.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/generator/retryable/retryable.go b/internal/generator/retryable/retryable.go index cdf3b9c..8b5d33e 100644 --- a/internal/generator/retryable/retryable.go +++ b/internal/generator/retryable/retryable.go @@ -13,8 +13,7 @@ const ( // Retryable is a code generator for a method that can be retried on error type Retryable struct { - method *method.Method - //originalTypeName string + method *method.Method structName string receiverName string } From 53288fd9c26b650531ec67949c02a9a97f90b122 Mon Sep 17 00:00:00 2001 From: Christian Sueiras Date: Sun, 7 Mar 2021 18:51:00 -0500 Subject: [PATCH 3/5] disable linting testdata --- internal/testpkg/teststruct.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/testpkg/teststruct.go b/internal/testpkg/teststruct.go index 3bc3258..0c3ca22 100644 --- a/internal/testpkg/teststruct.go +++ b/internal/testpkg/teststruct.go @@ -1,3 +1,5 @@ +//nolint + package testpkg import "context" @@ -16,5 +18,4 @@ func (s *service) GetUserByID(_ context.Context, _ string) (string, error) { type anotherService struct{} func (a anotherService) DoOperation() { - } From 49350d954ff73ad9f60f912c0f5fbd8ddd9d95c9 Mon Sep 17 00:00:00 2001 From: Christian Sueiras Date: Sun, 7 Mar 2021 18:52:54 -0500 Subject: [PATCH 4/5] Remove space --- internal/testpkg/teststruct.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/testpkg/teststruct.go b/internal/testpkg/teststruct.go index 0c3ca22..b4ea6f3 100644 --- a/internal/testpkg/teststruct.go +++ b/internal/testpkg/teststruct.go @@ -1,5 +1,4 @@ //nolint - package testpkg import "context" From b23fdb4660251872112cf29193c58b094fc26059 Mon Sep 17 00:00:00 2001 From: Christian Sueiras Date: Sun, 7 Mar 2021 19:00:14 -0500 Subject: [PATCH 5/5] Update docs to address support for structs --- README.md | 44 +++++++++++++++++++++++++++++--------- cmd/reinforcer/cmd/root.go | 6 +++--- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 0063f2a..a80e522 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # reinforcer + ![Tests](https://github.com/csueiras/reinforcer/workflows/run%20tests/badge.svg?branch=develop) [![Coverage Status](https://coveralls.io/repos/github/csueiras/reinforcer/badge.svg?branch=develop)](https://coveralls.io/github/csueiras/reinforcer?branch=develop) [![Go Report Card](https://goreportcard.com/badge/github.com/csueiras/reinforcer)](https://goreportcard.com/report/github.com/csueiras/reinforcer) @@ -9,13 +10,15 @@ Reinforcer is a code generation tool that automates middleware injection in a pr implementation, this aids in building more resilient code as you can use common resiliency patterns in the middlewares such as circuit breakers, retrying, timeouts and others. -**NOTE:** _While version is < 1.0.0 the APIs might dramatically change between minor versions, any breaking changes will be enumerated here starting with version 0.7.0 and forward._ +**NOTE:** _While version is < 1.0.0 the APIs might dramatically change between minor versions, any breaking changes will +be enumerated here starting with version 0.7.0 and forward._ ## Install ### Releases -Visit the [releases page](https://github.com/csueiras/reinforcer/releases) for pre-built binaries for OS X, Linux and Windows. +Visit the [releases page](https://github.com/csueiras/reinforcer/releases) for pre-built binaries for OS X, Linux and +Windows. ### Docker @@ -43,22 +46,26 @@ brew upgrade csueiras/reinforcer/reinforcer ### CLI -Generate reinforced code for all exported interfaces: +Generate reinforced code for all exported interfaces and structs: + ``` reinforcer --src=./service.go --targetall --outputdir=./reinforced ``` Generate reinforced code using regex: + ``` reinforcer --src=./service.go --target='.*Service' --outputdir=./reinforced ``` Generate reinforced code using an exact match: + ``` reinforcer --src=./service.go --target=MyService --outputdir=./reinforced ``` For more options: + ``` reinforcer --help ``` @@ -79,15 +86,14 @@ Flags: -p, --outpkg string name of generated package (default "reinforced") -o, --outputdir string directory to write the generated code to (default "./reinforced") -q, --silent disables logging. Mutually exclusive with the debug flag. - -s, --src strings source files to scan for the target interface. If unspecified the file pointed by the env variable GOFILE will be used. - -t, --target strings name of target type or regex to match interface names with - -a, --targetall codegen for all exported interfaces discovered. This option is mutually exclusive with the target option. + -s, --src strings source files to scan for the target interface or struct. If unspecified the file pointed by the env variable GOFILE will be used. + -t, --target strings name of target type or regex to match interface or struct names with + -a, --targetall codegen for all exported interfaces/structs discovered. This option is mutually exclusive with the target option. -v, --version show reinforcer's version ``` ### Using Reinforced Code - 1. Describe the target that you want to generate code for: ``` @@ -96,7 +102,25 @@ type Client interface { } ``` -2. Create the runner/middleware factory with the middlewares you want to inject into the generated code: +Or from a struct: + +``` +type Client struct { +} + +func (c *Client) DoOperation(ctx context.Context, arg string) error { + // ... + return nil +} +``` + +2. Generate the reinforcer code: + +``` +reinforcer --debug --src='./client.go' --target=Client --outputdir=./reinforced +``` + +3. Create the runner/middleware factory with the middlewares you want to inject into the generated code: ``` r := runner.NewFactory( @@ -108,7 +132,7 @@ r := runner.NewFactory( ) ``` -3. Optionally create your predicate for errors that shouldn't be retried +4. Optionally create your predicate for errors that shouldn't be retried ``` // shouldRetryErrPredicate is a predicate that ignores the "NotFound" errors emited by the DoOperation in Client. All other errors @@ -121,7 +145,7 @@ shouldRetryErrPredicate := func(method string, err error) bool { } ``` -4. Wrap the "real"/unrealiable implementation in the generated code: +5. Wrap the "real"/unrealiable implementation in the generated code: ``` c := client.NewClient(...) diff --git a/cmd/reinforcer/cmd/root.go b/cmd/reinforcer/cmd/root.go index 0f9ef4d..5bd09bc 100644 --- a/cmd/reinforcer/cmd/root.go +++ b/cmd/reinforcer/cmd/root.go @@ -154,9 +154,9 @@ such as circuit breaker, retries, timeouts, etc. flags.BoolP("version", "v", false, "show reinforcer's version") flags.BoolP("debug", "d", false, "enables debug logs") flags.BoolP("silent", "q", false, "disables logging. Mutually exclusive with the debug flag.") - flags.StringSliceP("src", "s", nil, "source files to scan for the target interface. If unspecified the file pointed by the env variable GOFILE will be used.") - flags.StringSliceP("target", "t", []string{}, "name of target type or regex to match interface names with") - flags.BoolP("targetall", "a", false, "codegen for all exported interfaces discovered. This option is mutually exclusive with the target option.") + flags.StringSliceP("src", "s", nil, "source files to scan for the target interface or struct. If unspecified the file pointed by the env variable GOFILE will be used.") + flags.StringSliceP("target", "t", []string{}, "name of target type or regex to match interface or struct names with") + flags.BoolP("targetall", "a", false, "codegen for all exported interfaces/structs discovered. This option is mutually exclusive with the target option.") flags.StringP("outputdir", "o", "./reinforced", "directory to write the generated code to") flags.StringP("outpkg", "p", "reinforced", "name of generated package") flags.BoolP("ignorenoret", "i", false, "ignores methods that don't return anything (they won't be wrapped in the middleware). By default they'll be wrapped in a middleware and if the middleware emits an error the call will panic.")