diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000..efca92444b --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,3 @@ +# For more information, please refer to https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners + +* @cloudwego/Kitex-reviewers @cloudwego/Kitex-approvers @cloudwego/Kitex-maintainers diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index 9f9df42c89..c156a606d0 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -14,7 +14,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Check Spell - uses: crate-ci/typos@master + uses: crate-ci/typos@v1.13.14 staticcheck: runs-on: [ self-hosted, X64 ] diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e60bc339d6..77b272e27c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,4 +48,4 @@ jobs: with: go-version: ${{ matrix.go }} - name: Unit Test - run: go test -gcflags=-l -race -covermode=atomic -coverprofile=coverage.txt ./... + run: go test -gcflags=-l -race -covermode=atomic ./... diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a487f3bb81..176fe35906 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,7 +31,7 @@ Before you submit your Pull Request (PR) consider the following guidelines: ``` git checkout -b my-fix-branch develop ``` -6. Create your patch, including appropriate test cases. +6. Create your patch, including appropriate test cases. Please refer to [Go-UT](https://pkg.go.dev/testing#pkg-overview) for writing guides. [Go-Mock](https://github.com/golang/mock) is recommended to mock interface, please refer to internal/mocks/readme.md for more details, and [Mockey](https://github.com/bytedance/mockey) is recommended to mock functions, please refer to its readme doc for specific usage. 7. Follow our [Style Guides](#code-style-guides). 8. Commit your changes using a descriptive commit message that follows [AngularJS Git Commit Message Conventions](https://docs.google.com/document/d/1QrDFcIiPjSLDn3EL15IJygNPiHORgU1_OOAqWjiDU5Y/edit). Adherence to these conventions is necessary because release notes are automatically generated from these messages. diff --git a/README.md b/README.md index e035c7a68b..2b8895735e 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Kitex has built-in code generation tools that support generating **Thrift**, **P - **Basic Features** - Including Message Type, Supported Protocols, Directly Invoke, Connection Pool, Timeout Control, Request Retry, LoadBalancer, Circuit Breaker, Rate Limiting, Instrumentation Control, Logging and HttpResolver.[[more]](https://www.cloudwego.io/docs/tutorials/basic-feature/) + Including Message Type, Supported Protocols, Directly Invoke, Connection Pool, Timeout Control, Request Retry, LoadBalancer, Circuit Breaker, Rate Limiting, Instrumentation Control, Logging and HttpResolver.[[more]](https://www.cloudwego.io/docs/kitex/tutorials/basic-feature/) - **Governance Features** diff --git a/client/callopt/options.go b/client/callopt/options.go index d8c78e258a..723e3de245 100644 --- a/client/callopt/options.go +++ b/client/callopt/options.go @@ -26,6 +26,7 @@ import ( "github.com/cloudwego/kitex/internal/client" "github.com/cloudwego/kitex/pkg/discovery" + "github.com/cloudwego/kitex/pkg/fallback" "github.com/cloudwego/kitex/pkg/http" "github.com/cloudwego/kitex/pkg/retry" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -45,6 +46,7 @@ type CallOptions struct { // export field for using in client RetryPolicy retry.Policy + Fallback *fallback.Policy } func newOptions() interface{} { @@ -63,6 +65,7 @@ func (co *CallOptions) Recycle() { co.configs = nil co.svr = nil co.RetryPolicy = retry.Policy{} + co.Fallback = nil co.locks.Zero() callOptionsPool.Put(co) } @@ -171,14 +174,14 @@ func WithTag(key, val string) Option { // WithRetryPolicy sets the retry policy for a RPC call. // Build retry.Policy with retry.BuildFailurePolicy or retry.BuildBackupRequest instead of building retry.Policy directly. -// Below is use demo, eg: +// Demos are provided below: // // demo1. call with failure retry policy, default retry error is Timeout -// resp, err := cli.Mock(ctx, req, callopt.WithRetryPolicy(retry.BuildFailurePolicy(retry.NewFailurePolicy()))) +// `resp, err := cli.Mock(ctx, req, callopt.WithRetryPolicy(retry.BuildFailurePolicy(retry.NewFailurePolicy())))` // demo2. call with backup request policy -// bp := retry.NewBackupPolicy(10) +// `bp := retry.NewBackupPolicy(10) // bp.WithMaxRetryTimes(1) -// resp, err := cli.Mock(ctx, req, callopt.WithRetryPolicy(retry.BuildBackupRequest(bp))) +// resp, err := cli.Mock(ctx, req, callopt.WithRetryPolicy(retry.BuildBackupRequest(bp)))` func WithRetryPolicy(p retry.Policy) Option { return Option{f: func(o *CallOptions, di *strings.Builder) { if !p.Enable { @@ -193,6 +196,23 @@ func WithRetryPolicy(p retry.Policy) Option { }} } +// WithFallback is used to set the fallback policy for a RPC call. +// Demos are provided below: +// +// demo1. call with fallback for error +// `resp, err := cli.Mock(ctx, req, callopt.WithFallback(fallback.ErrorFallback(yourFBFunc))` +// demo2. call with fallback for error and enable reportAsFallback, which sets reportAsFallback to be true and will do report(metric) as Fallback result +// `resp, err := cli.Mock(ctx, req, callopt.WithFallback(fallback.ErrorFallback(yourFBFunc).EnableReportAsFallback())` +func WithFallback(fb *fallback.Policy) Option { + return Option{f: func(o *CallOptions, di *strings.Builder) { + if !fallback.IsPolicyValid(fb) { + return + } + di.WriteString("WithFallback") + o.Fallback = fb + }} +} + // Apply applies call options to the rpcinfo.RPCConfig and internal.RemoteInfo of kitex client. // The return value records the name and arguments of each option. // This function is for internal purpose only. diff --git a/client/callopt/options_test.go b/client/callopt/options_test.go index 7b9835494c..f768d4d7f3 100644 --- a/client/callopt/options_test.go +++ b/client/callopt/options_test.go @@ -17,13 +17,16 @@ package callopt import ( + "context" "fmt" "testing" "time" "github.com/cloudwego/kitex/internal/client" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/fallback" "github.com/cloudwego/kitex/pkg/http" + "github.com/cloudwego/kitex/pkg/retry" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" ) @@ -89,4 +92,27 @@ func TestApply(t *testing.T) { v, exist := remoteInfo.Tag(mockKey) test.Assert(t, exist) test.Assert(t, v == mockVal, v) + + // WithRetryPolicy + option = WithRetryPolicy(retry.BuildFailurePolicy(retry.NewFailurePolicy())) + _, co := Apply([]Option{option}, rpcConfig, remoteInfo, client.NewConfigLocks(), http.NewDefaultResolver()) + test.Assert(t, co.RetryPolicy.Enable) + test.Assert(t, co.RetryPolicy.FailurePolicy != nil) + + // WithRetryPolicy pass empty struct + option = WithRetryPolicy(retry.Policy{}) + _, co = Apply([]Option{option}, rpcConfig, remoteInfo, client.NewConfigLocks(), http.NewDefaultResolver()) + test.Assert(t, !co.RetryPolicy.Enable) + + // WithFallback + option = WithFallback(fallback.ErrorFallback(fallback.UnwrapHelper(func(ctx context.Context, req, resp interface{}, err error) (fbResp interface{}, fbErr error) { + return + })).EnableReportAsFallback()) + _, co = Apply([]Option{option}, rpcConfig, remoteInfo, client.NewConfigLocks(), http.NewDefaultResolver()) + test.Assert(t, co.Fallback != nil) + + // WithFallback pass nil + option = WithFallback(nil) + _, co = Apply([]Option{option}, rpcConfig, remoteInfo, client.NewConfigLocks(), http.NewDefaultResolver()) + test.Assert(t, co.Fallback == nil) } diff --git a/client/client.go b/client/client.go index 5ec321c9ae..0c0bca3e8b 100644 --- a/client/client.go +++ b/client/client.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "runtime" + "runtime/debug" "strconv" "sync/atomic" @@ -34,6 +35,7 @@ import ( "github.com/cloudwego/kitex/pkg/discovery" "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/event" + "github.com/cloudwego/kitex/pkg/fallback" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/loadbalance" @@ -244,11 +246,14 @@ func (kc *kClient) initLBCache() error { NameFunc: func() string { return "no_resolver" }, } } + // because we cannot ensure that user's custom loadbalancer is cacheable, we need to disable it here + cacheOpts := lbcache.Options{DiagnosisService: kc.opt.DebugService, Cacheable: false} balancer := kc.opt.Balancer if balancer == nil { + // default internal lb balancer is cacheable + cacheOpts.Cacheable = true balancer = loadbalance.NewWeightedBalancer() } - cacheOpts := lbcache.Options{DiagnosisService: kc.opt.DebugService} if kc.opt.BalancerCacheOpt != nil { cacheOpts = *kc.opt.BalancerCacheOpt } @@ -365,48 +370,67 @@ func (kc *kClient) Call(ctx context.Context, method string, request, response in ctx = kc.opt.TracerCtl.DoStart(ctx, ri) var callOptRetry retry.Policy - if callOpts != nil && callOpts.RetryPolicy.Enable { - callOptRetry = callOpts.RetryPolicy - } - if kc.opt.RetryContainer == nil { - if callOptRetry.Enable { - // setup retry in callopt - kc.opt.RetryContainer = retry.NewRetryContainer() - } else { - err := kc.eps(ctx, request, response) - kc.opt.TracerCtl.DoFinish(ctx, ri, err) - if err == nil { - err = ri.Invocation().BizStatusErr() - rpcinfo.PutRPCInfo(ri) - } - return err + var callOptFallback *fallback.Policy + if callOpts != nil { + callOptFallback = callOpts.Fallback + if callOpts.RetryPolicy.Enable { + callOptRetry = callOpts.RetryPolicy } } + if kc.opt.RetryContainer == nil && callOptRetry.Enable { + // setup retry in callopt + kc.opt.RetryContainer = retry.NewRetryContainer() + } - var callTimes int32 - var prevRI rpcinfo.RPCInfo - recycleRI, err := kc.opt.RetryContainer.WithRetryIfNeeded(ctx, callOptRetry, func(ctx context.Context, r retry.Retryer) (rpcinfo.RPCInfo, interface{}, error) { - currCallTimes := int(atomic.AddInt32(&callTimes, 1)) - retryCtx := ctx - cRI := ri - if currCallTimes > 1 { - retryCtx, cRI, _ = kc.initRPCInfo(ctx, method) - retryCtx = metainfo.WithPersistentValue(retryCtx, retry.TransitKey, strconv.Itoa(currCallTimes-1)) - if prevRI == nil { - prevRI = ri + var err error + var recycleRI bool + if kc.opt.RetryContainer == nil { + err = kc.eps(ctx, request, response) + if err == nil { + recycleRI = true + } + } else { + var callTimes int32 + // prevRI represents a value of rpcinfo.RPCInfo type. + var prevRI atomic.Value + recycleRI, err = kc.opt.RetryContainer.WithRetryIfNeeded(ctx, callOptRetry, func(ctx context.Context, r retry.Retryer) (rpcinfo.RPCInfo, interface{}, error) { + currCallTimes := int(atomic.AddInt32(&callTimes, 1)) + retryCtx := ctx + cRI := ri + if currCallTimes > 1 { + retryCtx, cRI, _ = kc.initRPCInfo(ctx, method) + retryCtx = metainfo.WithPersistentValue(retryCtx, retry.TransitKey, strconv.Itoa(currCallTimes-1)) + if prevRI.Load() == nil { + prevRI.Store(ri) + } + r.Prepare(retryCtx, prevRI.Load().(rpcinfo.RPCInfo), cRI) + prevRI.Store(cRI) } - r.Prepare(retryCtx, prevRI, cRI) - prevRI = cRI + err := kc.eps(retryCtx, request, response) + return cRI, response, err + }, ri, request) + } + + // do fallback if with setup + fallback, hasFallback := getFallbackPolicy(callOptFallback, kc.opt.Fallback) + var fbErr error + reportErr := err + if hasFallback { + reportAsFB := false + // Notice: If rpc err is nil, rpcStatAsFB will always be false, even if it's set to true by user. + fbErr, reportAsFB = fallback.DoIfNeeded(ctx, ri, request, response, err) + if reportAsFB { + reportErr = fbErr } - err := kc.eps(retryCtx, request, response) - return cRI, response, err - }, ri, request) + err = fbErr + } - kc.opt.TracerCtl.DoFinish(ctx, ri, err) + kc.opt.TracerCtl.DoFinish(ctx, ri, reportErr) callOpts.Recycle() - if err == nil { + if err == nil && !hasFallback { err = ri.Invocation().BizStatusErr() } + if recycleRI { // why need check recycleRI to decide if recycle RPCInfo? // 1. no retry, rpc timeout happen will cause panic when response return @@ -503,6 +527,11 @@ func (kc *kClient) invokeHandleEndpoint() (endpoint.Endpoint, error) { // Close is not concurrency safe. func (kc *kClient) Close() error { + defer func() { + if err := recover(); err != nil { + klog.Warnf("KITEX: panic when close client, error=%s, stack=%s", err, string(debug.Stack())) + } + }() if kc.closed { return nil } @@ -631,3 +660,14 @@ func (kc *kClient) warmingUp() error { return nil } + +// return fallback policy from call option and client option. +func getFallbackPolicy(callOptFB, cliOptFB *fallback.Policy) (fb *fallback.Policy, hasFallback bool) { + if callOptFB != nil { + return callOptFB, true + } + if cliOptFB != nil { + return cliOptFB, true + } + return nil, false +} diff --git a/client/client_test.go b/client/client_test.go index 919bc79a5e..0cae9b98e2 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -34,8 +34,11 @@ import ( "github.com/cloudwego/kitex/internal/mocks" mocksnetpoll "github.com/cloudwego/kitex/internal/mocks/netpoll" mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" + mocksstats "github.com/cloudwego/kitex/internal/mocks/stats" + mockthrift "github.com/cloudwego/kitex/internal/mocks/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/fallback" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/retry" @@ -507,10 +510,10 @@ func TestRetry(t *testing.T) { defer ctrl.Finish() var count int32 - md := func(next endpoint.Endpoint) endpoint.Endpoint { + sleepMW := func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, req, resp interface{}) (err error) { if atomic.CompareAndSwapInt32(&count, 0, 1) { - time.Sleep(time.Second) + time.Sleep(300 * time.Millisecond) } return nil } @@ -518,8 +521,8 @@ func TestRetry(t *testing.T) { // should timeout cli := newMockClient(t, ctrl, - WithMiddleware(md), - WithRPCTimeout(500*time.Millisecond), + WithMiddleware(sleepMW), + WithRPCTimeout(100*time.Millisecond), WithFailureRetry(&retry.FailurePolicy{ StopPolicy: retry.StopPolicy{ MaxRetryTimes: 3, @@ -535,6 +538,258 @@ func TestRetry(t *testing.T) { test.Assert(t, err == nil, err) } +func TestRetryWithResultRetry(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockErr := errors.New("mock") + retryWithMockErr := false + var count int32 + errMW := func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, req, resp interface{}) (err error) { + if atomic.CompareAndSwapInt32(&count, 0, 1) { + return mockErr + } + return nil + } + } + errRetryFunc := func(err error, ri rpcinfo.RPCInfo) bool { + if errors.Is(err, mockErr) { + retryWithMockErr = true + return true + } + return false + } + + // should timeout + cli := newMockClient(t, ctrl, + WithMiddleware(errMW), + WithRPCTimeout(100*time.Millisecond), + WithFailureRetry(&retry.FailurePolicy{ + StopPolicy: retry.StopPolicy{ + MaxRetryTimes: 3, + CBPolicy: retry.CBPolicy{ + ErrorRate: 0.1, + }, + }, + RetrySameNode: true, + ShouldResultRetry: &retry.ShouldResultRetry{ErrorRetry: errRetryFunc}, + })) + mtd := mocks.MockMethod + req, res := new(MockTStruct), new(MockTStruct) + err := cli.Call(context.Background(), mtd, req, res) + test.Assert(t, err == nil, err) + test.Assert(t, retryWithMockErr) +} + +func TestFallbackForError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + isTimeout := false + isMockErr := false + isBizErr := false + errForReportIsNil := false + var rpcResult interface{} + + // prepare mock data + retStr := "success" + sleepMW := func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, req, resp interface{}) (err error) { + time.Sleep(300 * time.Millisecond) + return nil + } + } + mockErr := errors.New("mock") + mockErrMW := func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, req, resp interface{}) (err error) { + return mockErr + } + } + bizErrMW := func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, req, resp interface{}) (err error) { + if setter, ok := rpcinfo.GetRPCInfo(ctx).Invocation().(rpcinfo.InvocationSetter); ok { + setter.SetBizStatusErr(kerrors.NewBizStatusError(100, mockErr.Error())) + } + return nil + } + } + + mockTracer := mocksstats.NewMockTracer(ctrl) + mockTracer.EXPECT().Start(gomock.Any()).DoAndReturn(func(ctx context.Context) context.Context { + return ctx + }).AnyTimes() + mockTracer.EXPECT().Finish(gomock.Any()).DoAndReturn(func(ctx context.Context) { + if rpcinfo.GetRPCInfo(ctx).Stats().Error() == nil { + errForReportIsNil = true + } else { + errForReportIsNil = false + } + }).AnyTimes() + newFallback := func(realReqResp bool, testCase string) *fallback.Policy { + var fallbackFunc fallback.Func + checkErr := func(err error) { + isTimeout, isMockErr, isBizErr = false, false, false + if errors.Is(err, kerrors.ErrRPCTimeout) { + isTimeout = true + } + if errors.Is(err, mockErr) { + isMockErr = true + } + if _, ok := kerrors.FromBizStatusError(err); ok { + isBizErr = true + } + } + if realReqResp { + fallbackFunc = fallback.UnwrapHelper(func(ctx context.Context, req, resp interface{}, err error) (fbResp interface{}, fbErr error) { + checkErr(err) + _, ok := req.(*mockthrift.MockReq) + test.Assert(t, ok, testCase) + return &retStr, nil + }) + } else { + fallbackFunc = func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) { + checkErr(err) + _, ok := args.(*mockthrift.MockTestArgs) + test.Assert(t, ok, testCase) + result.SetSuccess(&retStr) + return nil + } + } + return fallback.NewFallbackPolicy(fallbackFunc) + } + + // case 1: fallback for timeout, return nil err, but report original err + cli := newMockClient(t, ctrl, + WithMiddleware(sleepMW), + WithRPCTimeout(100*time.Millisecond), + WithFallback(newFallback(false, "case 1")), + WithTracer(mockTracer), + ) + mtd := mocks.MockMethod + rpcResult = mockthrift.NewMockTestResult() + err := cli.Call(context.Background(), mtd, mockthrift.NewMockTestArgs(), rpcResult) + test.Assert(t, err == nil, err) + test.Assert(t, *rpcResult.(utils.KitexResult).GetResult().(*string) == retStr) + test.Assert(t, isTimeout, err) + test.Assert(t, !errForReportIsNil) + + // case 2: fallback for mock error, but report original err + cli = newMockClient(t, ctrl, + WithMiddleware(mockErrMW), + WithFallback(newFallback(false, "case 2"))) + rpcResult = mockthrift.NewMockTestResult() + err = cli.Call(context.Background(), mtd, mockthrift.NewMockTestArgs(), rpcResult) + test.Assert(t, err == nil, err) + test.Assert(t, *rpcResult.(utils.KitexResult).GetResult().(*string) == retStr) + test.Assert(t, isMockErr, err) + test.Assert(t, !errForReportIsNil) + + // case 3: fallback for timeout, return nil err, and report nil err as enable reportAsFallback + cli = newMockClient(t, ctrl, + WithMiddleware(sleepMW), + WithRPCTimeout(100*time.Millisecond), + WithFallback(newFallback(true, "case 3").EnableReportAsFallback()), + WithTracer(mockTracer), + ) + // reset + isTimeout = false + errForReportIsNil = false + rpcResult = mockthrift.NewMockTestResult() + err = cli.Call(context.Background(), mtd, mockthrift.NewMockTestArgs(), rpcResult) + test.Assert(t, err == nil, err) + test.Assert(t, *rpcResult.(utils.KitexResult).GetResult().(*string) == retStr) + test.Assert(t, isTimeout, err) + test.Assert(t, errForReportIsNil) + + // case 4: no fallback, return biz err, but report nil err + cli = newMockClient(t, ctrl, + WithMiddleware(bizErrMW), + WithRPCTimeout(100*time.Millisecond), + WithTracer(mockTracer), + ) + // reset + errForReportIsNil = false + rpcResult = mockthrift.NewMockTestResult() + err = cli.Call(context.Background(), mtd, mockthrift.NewMockTestArgs(), rpcResult) + _, ok := kerrors.FromBizStatusError(err) + test.Assert(t, ok, err) + test.Assert(t, errForReportIsNil) + + // case 5: fallback for biz error, return nil err, + // and report nil err even if don't enable reportAsFallback as biz error won't report failure by design + cli = newMockClient(t, ctrl, + WithMiddleware(bizErrMW), + WithRPCTimeout(100*time.Millisecond), + WithFallback(newFallback(true, "case 5")), + WithTracer(mockTracer), + ) + // reset + isBizErr = false + errForReportIsNil = false + rpcResult = mockthrift.NewMockTestResult() + err = cli.Call(context.Background(), mtd, mockthrift.NewMockTestArgs(), rpcResult) + test.Assert(t, err == nil, err) + test.Assert(t, *rpcResult.(utils.KitexResult).GetResult().(*string) == retStr) + test.Assert(t, isBizErr, err) + test.Assert(t, errForReportIsNil) + + // case 6: fallback for timeout error, return nil err + cli = newMockClient(t, ctrl, + WithMiddleware(sleepMW), + WithRPCTimeout(100*time.Millisecond), + WithFallback(fallback.TimeoutAndCBFallback(fallback.UnwrapHelper(func(ctx context.Context, req, resp interface{}, err error) (fbResp interface{}, fbErr error) { + if errors.Is(err, kerrors.ErrRPCTimeout) { + isTimeout = true + } + return &retStr, nil + }))), + WithTracer(mockTracer), + ) + // reset + isTimeout = false + errForReportIsNil = true + rpcResult = mockthrift.NewMockTestResult() + err = cli.Call(context.Background(), mtd, mockthrift.NewMockTestArgs(), rpcResult) + test.Assert(t, err == nil, err) + test.Assert(t, *rpcResult.(utils.KitexResult).GetResult().(*string) == retStr, rpcResult) + test.Assert(t, isTimeout, err) + test.Assert(t, !errForReportIsNil) + + // case 6: use TimeoutAndCBFallback, non-timeout and non-CB error cannot do fallback + var fallbackExecuted bool + cli = newMockClient(t, ctrl, + WithMiddleware(mockErrMW), + WithRPCTimeout(100*time.Millisecond), + WithFallback(fallback.TimeoutAndCBFallback(func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) { + fallbackExecuted = true + return + })), + WithTracer(mockTracer), + ) + // reset + errForReportIsNil = true + rpcResult = mockthrift.NewMockTestResult() + err = cli.Call(context.Background(), mtd, mockthrift.NewMockTestArgs(), rpcResult) + test.Assert(t, err != nil, err) + test.Assert(t, !fallbackExecuted) + + // case 7: fallback return nil-resp and nil-err, framework will return the original resp and err + cli = newMockClient(t, ctrl, + WithMiddleware(mockErrMW), + WithRPCTimeout(100*time.Millisecond), + WithFallback(fallback.NewFallbackPolicy(func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) { + return + })), + WithTracer(mockTracer), + ) + mtd = mocks.MockMethod + rpcResult = mockthrift.NewMockTestResult() + err = cli.Call(context.Background(), mtd, mockthrift.NewMockTestArgs(), rpcResult) + test.Assert(t, err == mockErr) + test.Assert(t, !errForReportIsNil) +} + func TestClientFinalizer(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/client/middlewares.go b/client/middlewares.go index aff7244c81..341b4c1ba4 100644 --- a/client/middlewares.go +++ b/client/middlewares.go @@ -29,6 +29,7 @@ import ( "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/event" "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/loadbalance/lbcache" "github.com/cloudwego/kitex/pkg/proxy" "github.com/cloudwego/kitex/pkg/remote" @@ -37,6 +38,8 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" ) +const maxRetry = 6 + func newProxyMW(prx proxy.ForwardProxy) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request, response interface{}) error { @@ -75,10 +78,6 @@ func discoveryEventHandler(name string, bus event.Bus, queue event.Queue) func(d // If retryable error is encountered, it will retry until timeout or an unretryable error is returned. func newResolveMWBuilder(lbf *lbcache.BalancerFactory) endpoint.MiddlewareBuilder { return func(ctx context.Context) endpoint.Middleware { - retryable := func(err error) bool { - return errors.Is(err, kerrors.ErrGetConnection) || errors.Is(err, kerrors.ErrCircuitBreak) - } - return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request, response interface{}) error { rpcInfo := rpcinfo.GetRPCInfo(ctx) @@ -101,31 +100,39 @@ func newResolveMWBuilder(lbf *lbcache.BalancerFactory) endpoint.MiddlewareBuilde return kerrors.ErrServiceDiscovery.WithCause(err) } - picker := lb.GetPicker() - if r, ok := picker.(internal.Reusable); ok { - defer r.Recycle() - } var lastErr error - for { + for i := 0; i < maxRetry; i++ { select { case <-ctx.Done(): return kerrors.ErrRPCTimeout default: } + // we always need to get a new picker every time, because when downstream update deployment, + // we may get an old picker that include all outdated instances which will cause connect always failed. + picker := lb.GetPicker() ins := picker.Next(ctx, request) if ins == nil { - return kerrors.ErrNoMoreInstance.WithCause(fmt.Errorf("last error: %w", lastErr)) + err = kerrors.ErrNoMoreInstance.WithCause(fmt.Errorf("last error: %w", lastErr)) + } else { + remote.SetInstance(ins) + // TODO: generalize retry strategy + err = next(ctx, request, response) } - remote.SetInstance(ins) - - // TODO: generalize retry strategy - if err = next(ctx, request, response); err != nil && retryable(err) { + if r, ok := picker.(internal.Reusable); ok { + r.Recycle() + } + if err == nil { + return nil + } + if retryable(err) { lastErr = err + klog.CtxWarnf(ctx, "KITEX: auto retry retryable error, retry=%d error=%s", i+1, err.Error()) continue } return err } + return lastErr } } } @@ -178,3 +185,7 @@ func wrapInstances(insts []discovery.Instance) []*instInfo { } return instInfos } + +func retryable(err error) bool { + return errors.Is(err, kerrors.ErrGetConnection) || errors.Is(err, kerrors.ErrCircuitBreak) +} diff --git a/client/option.go b/client/option.go index 8e70dcff6e..b7f2c6b106 100644 --- a/client/option.go +++ b/client/option.go @@ -29,6 +29,7 @@ import ( "github.com/cloudwego/kitex/pkg/connpool" "github.com/cloudwego/kitex/pkg/discovery" "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/fallback" "github.com/cloudwego/kitex/pkg/http" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/loadbalance" @@ -381,14 +382,33 @@ func WithRetryMethodPolicies(mp map[string]retry.Policy) Option { // But if your retry policy is enabled by remote config, WithSpecifiedResultRetry is useful. func WithSpecifiedResultRetry(rr *retry.ShouldResultRetry) Option { return Option{F: func(o *client.Options, di *utils.Slice) { - if rr == nil { - return + if rr == nil || (rr.RespRetry == nil && rr.ErrorRetry == nil) { + panic(fmt.Errorf("WithSpecifiedResultRetry: invalid '%+v'", rr)) } di.Push(fmt.Sprintf("WithSpecifiedResultRetry(%+v)", rr)) o.RetryWithResult = rr }} } +// WithFallback is used to set the fallback policy for the client. +// Demos are provided below: +// +// demo1. fallback for error and resp +// `client.WithFallback(fallback.NewFallbackPolicy(yourFBFunc))` +// demo2. fallback for error and enable reportAsFallback, which sets reportAsFallback to be true and will do report(metric) as Fallback result +// `client.WithFallback(fallback.ErrorFallback(yourErrFBFunc).EnableReportAsFallback())` +// demo2. fallback for rpctime and circuit breaker +// `client.WithFallback(fallback.TimeoutAndCBFallback(yourErrFBFunc))` +func WithFallback(fb *fallback.Policy) Option { + return Option{F: func(o *client.Options, di *utils.Slice) { + if !fallback.IsPolicyValid(fb) { + panic(fmt.Errorf("WithFallback: invalid '%+v'", fb)) + } + di.Push(fmt.Sprintf("WithFallback(%+v)", fb)) + o.Fallback = fb + }} +} + // WithCircuitBreaker adds a circuitbreaker suite for the client. func WithCircuitBreaker(s *circuitbreak.CBSuite) Option { return Option{F: func(o *client.Options, di *utils.Slice) { diff --git a/client/option_advanced.go b/client/option_advanced.go index 084631df50..d3f7f353c3 100644 --- a/client/option_advanced.go +++ b/client/option_advanced.go @@ -20,6 +20,7 @@ package client // It is used for customized extension. import ( + "crypto/tls" "fmt" "reflect" @@ -231,3 +232,11 @@ func WithBoundHandler(h remote.BoundHandler) Option { } }} } + +// WithGRPCTLSConfig sets the TLS config for gRPC client. +func WithGRPCTLSConfig(tlsConfig *tls.Config) Option { + return Option{F: func(o *client.Options, di *utils.Slice) { + di.Push("WithGRPCTLSConfig") + o.GRPCConnectOpts.TLSConfig = tlsConfig + }} +} diff --git a/client/option_test.go b/client/option_test.go index 0475c1b4cf..bb418f86ee 100644 --- a/client/option_test.go +++ b/client/option_test.go @@ -539,7 +539,7 @@ func TestWithFirstMetaHandler(t *testing.T) { func TestWithMetaHandler(t *testing.T) { mockMetaHandler := &mock_remote.MockMetaHandler{} opts := client.NewOptions([]client.Option{WithMetaHandler(mockMetaHandler)}) - test.DeepEqual(t, opts.MetaHandlers[0], mockMetaHandler) + test.DeepEqual(t, opts.MetaHandlers[1], mockMetaHandler) } func TestWithConnPool(t *testing.T) { diff --git a/client/rpctimeout.go b/client/rpctimeout.go index a6a3fa7b4d..9cfd49fe41 100644 --- a/client/rpctimeout.go +++ b/client/rpctimeout.go @@ -63,13 +63,16 @@ func makeTimeoutErr(ctx context.Context, start time.Time, timeout time.Duration) // cancel error if ctx.Err() == context.Canceled { - return kerrors.ErrRPCTimeout.WithCause(fmt.Errorf("%s: %w", errMsg, ctx.Err())) + return kerrors.ErrRPCTimeout.WithCause(fmt.Errorf("%s: %w by business", errMsg, ctx.Err())) } if ddl, ok := ctx.Deadline(); !ok { errMsg = fmt.Sprintf("%s, %s", errMsg, "unknown error: context deadline not set?") } else { - if ddl.Before(start.Add(timeout)) { + // Go's timer implementation is not so accurate, + // so if we need to check ctx deadline earlier than our timeout, we should consider the accuracy + roundTimeout := timeout - time.Millisecond + if roundTimeout >= 0 && ddl.Before(start.Add(roundTimeout)) { errMsg = fmt.Sprintf("%s, context deadline earlier than timeout, actual=%v", errMsg, ddl.Sub(start)) } } @@ -91,8 +94,9 @@ func rpcTimeoutMW(mwCtx context.Context) endpoint.Middleware { tm := ri.Config().RPCTimeout() if tm > 0 { + tm += moreTimeout var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, tm+moreTimeout) + ctx, cancel = context.WithTimeout(ctx, tm) defer cancel() } // Fast path for ctx without timeout diff --git a/go.mod b/go.mod index 2ee73e6eb5..81ec5cec9c 100644 --- a/go.mod +++ b/go.mod @@ -4,21 +4,22 @@ go 1.13 require ( github.com/apache/thrift v0.13.0 - github.com/bytedance/gopkg v0.0.0-20220531084716-665b4f21126f - github.com/choleraehyq/pid v0.0.15 - github.com/cloudwego/fastpb v0.0.3 - github.com/cloudwego/frugal v0.1.3 - github.com/cloudwego/netpoll v0.3.1 - github.com/cloudwego/thriftgo v0.2.4 + github.com/bytedance/gopkg v0.0.0-20220817015305-b879a72dc90f + github.com/bytedance/mockey v1.2.0 + github.com/choleraehyq/pid v0.0.16 + github.com/cloudwego/fastpb v0.0.4-0.20230131074846-6fc453d58b96 + github.com/cloudwego/frugal v0.1.6 + github.com/cloudwego/netpoll v0.3.2 + github.com/cloudwego/thriftgo v0.2.8 github.com/golang/mock v1.6.0 github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 github.com/jhump/protoreflect v1.8.2 github.com/json-iterator/go v1.1.12 github.com/tidwall/gjson v1.9.3 - golang.org/x/net v0.0.0-20210614182718-04defd469f4e - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c + golang.org/x/net v0.0.0-20220722155237-a158d28d115b + golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2 - golang.org/x/tools v0.1.1 + golang.org/x/tools v0.1.12 google.golang.org/genproto v0.0.0-20210513213006-bf773b8c8384 google.golang.org/protobuf v1.28.1 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index 51077ba686..a162d0f08a 100644 --- a/go.sum +++ b/go.sum @@ -1,33 +1,39 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= +git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY= +github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= +github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/brianvoe/gofakeit/v6 v6.16.0/go.mod h1:Ow6qC71xtwm79anlwKRlWZW6zVq9D2XHE4QSSMP/rU8= +github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= -github.com/bytedance/gopkg v0.0.0-20220509134931-d1878f638986/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= -github.com/bytedance/gopkg v0.0.0-20220531084716-665b4f21126f h1:2YCF3cgO6XCub0HIsLrA8ZGhmAPGZfOeSaGjT6Kx4Mc= -github.com/bytedance/gopkg v0.0.0-20220531084716-665b4f21126f/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= +github.com/bytedance/gopkg v0.0.0-20220817015305-b879a72dc90f h1:U3Bk6S9UyqFM5tU3bZ3pwqx5xyypHP7Bm2QCbOUwxSc= +github.com/bytedance/gopkg v0.0.0-20220817015305-b879a72dc90f/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= +github.com/bytedance/mockey v1.2.0 h1:847+X2fBSM4s/AIN4loO5d16PCgEj53j7Q8YVB+8P6c= +github.com/bytedance/mockey v1.2.0/go.mod h1:+Jm/fzWZAuhEDrPXVjDf/jLM2BlLXJkwk94zf2JZ3X4= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/chenzhuoyu/iasm v0.0.0-20220818063314-28c361dae733 h1:Hx6Jxqln+bHRrtjUdgrehhF3gtWVJ2S7bjO/YTNn8Fg= -github.com/chenzhuoyu/iasm v0.0.0-20220818063314-28c361dae733/go.mod h1:wOQ0nsbeOLa2awv8bUYFW/EHXbjQMlZ10fAlXDB2sz8= -github.com/choleraehyq/pid v0.0.13/go.mod h1:uhzeFgxJZWQsZulelVQZwdASxQ9TIPZYL4TPkQMtL/U= -github.com/choleraehyq/pid v0.0.15 h1:PejhUZowqxxssjwyaw4OZURRFjnvftZfeEWK9UoWPXU= -github.com/choleraehyq/pid v0.0.15/go.mod h1:uhzeFgxJZWQsZulelVQZwdASxQ9TIPZYL4TPkQMtL/U= +github.com/chenzhuoyu/iasm v0.0.0-20230222070914-0b1b64b0e762 h1:4+00EOUb1t9uxAbgY8VvgfKJKDpim3co4MqsAbelIbs= +github.com/chenzhuoyu/iasm v0.0.0-20230222070914-0b1b64b0e762/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog= +github.com/choleraehyq/pid v0.0.16 h1:1/714sMH9IBlE/aK6xM0acTagGKSzpiR0bDt7l0cG7o= +github.com/choleraehyq/pid v0.0.16/go.mod h1:uhzeFgxJZWQsZulelVQZwdASxQ9TIPZYL4TPkQMtL/U= github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cloudwego/fastpb v0.0.3 h1:GZE0WzlnjQFE3+vkYFZd964PGT9AXOuvir+JGzuBSPM= -github.com/cloudwego/fastpb v0.0.3/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= -github.com/cloudwego/frugal v0.1.3 h1:tw3+hh4YMmtHFHRue3OGYjAnkxnZRHqeAyG18+7z5aI= -github.com/cloudwego/frugal v0.1.3/go.mod h1:b981ViPYdhI56aFYsoMjl9kv6yeqYSO+iEz2jrhkCgI= -github.com/cloudwego/kitex v0.3.2/go.mod h1:/XD07VpUD9VQWmmoepASgZ6iw//vgWikVA9MpzLC5i0= -github.com/cloudwego/netpoll v0.2.4/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E= -github.com/cloudwego/netpoll v0.3.1 h1:xByoORmCLIyKZ8gS+da06WDo3j+jvmhaqS2KeKejtBk= -github.com/cloudwego/netpoll v0.3.1/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E= -github.com/cloudwego/thriftgo v0.1.2/go.mod h1:LzeafuLSiHA9JTiWC8TIMIq64iadeObgRUhmVG1OC/w= -github.com/cloudwego/thriftgo v0.2.4 h1:o3JTSygQXaNHmggZYqAkfCBdPGWuKH1Q8XCflCvsSIY= -github.com/cloudwego/thriftgo v0.2.4/go.mod h1:8i9AF5uDdWHGqzUhXDlubCjx4MEfKvWXGQlMWyR0tM4= +github.com/cloudwego/fastpb v0.0.4-0.20230131074846-6fc453d58b96 h1:61PQT0CXNUuQDiDKv/QQ+pFi9uthExZLQz8b5WfS7Qw= +github.com/cloudwego/fastpb v0.0.4-0.20230131074846-6fc453d58b96/go.mod h1:/V13XFTq2TUkxj2qWReV8MwfPC4NnPcy6FsrojnsSG0= +github.com/cloudwego/frugal v0.1.6 h1:aXJ7W0Omion1WTCe4JHAWinQmjXDYzHt03sabu3Rabo= +github.com/cloudwego/frugal v0.1.6/go.mod h1:9ElktKsh5qd2zDBQ5ENhPSQV7F2dZ/mXlr1eaZGDBFs= +github.com/cloudwego/netpoll v0.3.2 h1:/998ICrNMVBo4mlul4j7qcIeY7QnEfuCCPPwck9S3X4= +github.com/cloudwego/netpoll v0.3.2/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/thriftgo v0.2.8 h1:swwp+JQDeL8bBbvzJN3D3J5fluWP+chiUqVPbnToV0I= +github.com/cloudwego/thriftgo v0.2.8/go.mod h1:dAyXHEmKXo0LfMCrblVEY3mUZsdeuA5+i0vF5f09j7E= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -36,7 +42,20 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= +github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks= +github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= +github.com/go-fonts/liberation v0.2.0/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= +github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY= +github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= +github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U= +github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk= +github.com/go-pdf/fpdf v0.5.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= +github.com/go-pdf/fpdf v0.6.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -58,30 +77,36 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 h1:mpL/HvfIgIejhVwAfxBQkwEjlhP5o0O9RAeTAjpwzxc= github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3/go.mod h1:gSuNB+gJaOiQKLEZ+q+PK9Mq3SOzhRcw2GsGS/FhYDk= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/jhump/protoreflect v1.8.2 h1:k2xE7wcUomeqwY0LDCYA16y4WWfyTcMx5mKhk0d4ua0= github.com/jhump/protoreflect v1.8.2/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= +github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.1.0/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/mattn/go-isatty v0.0.13/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= @@ -89,11 +114,23 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY github.com/nishanths/predeclared v0.0.0-20200524104333-86fad755b4d3/go.mod h1:nt3d53pc1VYcphSCIaYAJtnPYnr3Zyn8fMq2wvPGPso= github.com/oleiade/lane v1.0.1 h1:hXofkn7GEOubzTwNpeL9MaNy8WxolCYb9cInAIeqShU= github.com/oleiade/lane v1.0.1/go.mod h1:IyTkraa4maLfjq/GmHR+Dxb4kCMtEGeb+qmhlrQ5Mk4= +github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2dXMnm1mY= +github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= +github.com/phpdave11/gofpdi v1.0.13/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w= +github.com/ruudk/golang-pdf417 v0.0.0-20201230142125-a7e3863a1245/go.mod h1:pQAZKsJ8yyVxGRWYNEm9oFB8ieLgKFnamEyDmSA0BRk= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= @@ -108,22 +145,49 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -golang.org/x/arch v0.0.0-20220722155209-00200b7164a7 h1:VBQqJMNMRfQsWSiCTLgz9XjAfWlgnJAPv8nsp1HF8Tw= -golang.org/x/arch v0.0.0-20220722155209-00200b7164a7/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= +golang.org/x/arch v0.2.0 h1:W1sUEHXiJTfjaFJ5SLo0N6lZn+0eO5gWD1MFeTGqQEY= +golang.org/x/arch v0.2.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE= +golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= +golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.0.0-20200119044424-58c23975cae1/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.0.0-20210607152325-775e3b0c77b9/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= +golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= +golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= +golang.org/x/image v0.0.0-20220302094943-723b81ca9867/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -133,10 +197,12 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210614182718-04defd469f4e h1:XpT3nA5TvE525Ne3hInMh6+GETgn27Zfm9dxsThnX2Q= -golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -144,38 +210,50 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210304124612-50617c2ba197/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210818153620-00dd8d7831e7/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2 h1:fqTvyMIIj+HRzMmnzr9NtpHP6uVpvB5fkHcgPDC4nu8= golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= @@ -184,13 +262,23 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20200717024301-6ddee64345a6/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= -golang.org/x/tools v0.1.1 h1:wGiQel/hW0NnEkJUk8lbzkX2gFJU6PFxf1v5OlCfuOs= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= +gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= +gonum.org/v1/gonum v0.9.3/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0= +gonum.org/v1/gonum v0.12.0/go.mod h1:73TDxJfAAHeA8Mk9mf8NlIppyhQNo5GLTcYeqgo2lvY= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= +gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= +gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY= +gonum.org/v1/plot v0.10.1/go.mod h1:VZW5OlhkL1mysU9vaqNHnsy86inf6Ot+jB3r+BczCEo= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= @@ -216,7 +304,6 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.25.1-0.20200805231151-a709e31e5d12/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= @@ -226,11 +313,11 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/client/option.go b/internal/client/option.go index 629e080319..5cd3007593 100644 --- a/internal/client/option.go +++ b/internal/client/option.go @@ -29,6 +29,7 @@ import ( "github.com/cloudwego/kitex/pkg/discovery" "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/event" + "github.com/cloudwego/kitex/pkg/fallback" "github.com/cloudwego/kitex/pkg/http" "github.com/cloudwego/kitex/pkg/loadbalance" "github.com/cloudwego/kitex/pkg/loadbalance/lbcache" @@ -98,6 +99,9 @@ type Options struct { RetryContainer *retry.Container RetryWithResult *retry.ShouldResultRetry + // fallback policy + Fallback *fallback.Policy + CloseCallbacks []func() error WarmUpOption *warmup.ClientOption @@ -127,6 +131,7 @@ func NewOptions(opts []Option) *Options { o := &Options{ Cli: &rpcinfo.EndpointBasicInfo{Tags: make(map[string]string)}, Svr: &rpcinfo.EndpointBasicInfo{Tags: make(map[string]string)}, + MetaHandlers: []remote.MetaHandler{transmeta.MetainfoClientHandler}, RemoteOpt: newClientRemoteOption(), Configs: rpcinfo.NewRPCConfig(), Locks: NewConfigLocks(), @@ -142,7 +147,6 @@ func NewOptions(opts []Option) *Options { GRPCConnectOpts: new(grpc.ConnectOptions), } o.Apply(opts) - o.MetaHandlers = append(o.MetaHandlers, transmeta.MetainfoClientHandler) o.initRemoteOpt() diff --git a/internal/mocks/README.md b/internal/mocks/README.md new file mode 100644 index 0000000000..3ac9a72578 --- /dev/null +++ b/internal/mocks/README.md @@ -0,0 +1,13 @@ +# Running Prerequisites + +- Run command `go install github.com/golang/mock/mockgen@latest` to install gomock +- Run command `git clone github.com/cloudwego/netpoll` to clone netpoll in the parent directory of kitex. + +# User's Guidance + +- Add a line under the `files` parameter of update.sh. +- Fill in the go file path where the interface you want to mock is located. +- Fill in the path of the output mock file. +- Fill in the package name of the output mock file. +- Run `sh update.sh` to update mock files. +- Now you can use the generated gomock class to mock interface. Refer to https://github.com/golang/mock. \ No newline at end of file diff --git a/internal/mocks/stats/tracer.go b/internal/mocks/stats/tracer.go new file mode 100644 index 0000000000..89a68288cd --- /dev/null +++ b/internal/mocks/stats/tracer.go @@ -0,0 +1,77 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +// Code generated by MockGen. DO NOT EDIT. +// Source: ../../pkg/stats/tracer.go + +// Package tracer is a generated GoMock package. +package stats + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockTracer is a mock of Tracer interface. +type MockTracer struct { + ctrl *gomock.Controller + recorder *MockTracerMockRecorder +} + +// MockTracerMockRecorder is the mock recorder for MockTracer. +type MockTracerMockRecorder struct { + mock *MockTracer +} + +// NewMockTracer creates a new mock instance. +func NewMockTracer(ctrl *gomock.Controller) *MockTracer { + mock := &MockTracer{ctrl: ctrl} + mock.recorder = &MockTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTracer) EXPECT() *MockTracerMockRecorder { + return m.recorder +} + +// Finish mocks base method. +func (m *MockTracer) Finish(ctx context.Context) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Finish", ctx) +} + +// Finish indicates an expected call of Finish. +func (mr *MockTracerMockRecorder) Finish(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockTracer)(nil).Finish), ctx) +} + +// Start mocks base method. +func (m *MockTracer) Start(ctx context.Context) context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start", ctx) + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockTracerMockRecorder) Start(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockTracer)(nil).Start), ctx) +} diff --git a/internal/mocks/thrift/test.go b/internal/mocks/thrift/test.go index 7a38dce475..7d2e01213d 100644 --- a/internal/mocks/thrift/test.go +++ b/internal/mocks/thrift/test.go @@ -519,6 +519,10 @@ func (p *MockTestArgs) SetReq(val *MockReq) { p.Req = val } +func (p *MockTestArgs) GetFirstArgument() (interface{}) { + return p.Req +} + var fieldIDToName_MockTestArgs = map[int16]string{ 1: "req", } @@ -687,6 +691,10 @@ func (p *MockTestResult) SetSuccess(x interface{}) { p.Success = x.(*string) } +func (p *MockTestResult) GetResult() interface{} { + return p.Success +} + var fieldIDToName_MockTestResult = map[int16]string{ 0: "success", } diff --git a/internal/mocks/update.sh b/internal/mocks/update.sh index c4909245d0..87fd9706b0 100755 --- a/internal/mocks/update.sh +++ b/internal/mocks/update.sh @@ -5,6 +5,7 @@ cd $(dirname "${BASH_SOURCE[0]}") # source file => output file => package name files=( ../../pkg/limiter/limiter.go limiter/limiter.go limiter +../../pkg/stats/tracer.go stats/tracer.go stats ../../pkg/remote/trans_handler.go remote/trans_handler.go remote ../../pkg/remote/codec.go remote/codec.go remote ../../pkg/remote/connpool.go remote/connpool.go remote @@ -18,6 +19,7 @@ files=( ../../pkg/discovery/discovery.go discovery/discovery.go discovery ../../pkg/loadbalance/loadbalancer.go loadbalance/loadbalancer.go loadbalance ../../pkg/proxy/proxy.go proxy/proxy.go proxy +../../pkg/utils/sharedticker.go utils/sharedticker.go utils ../../../netpoll/connection.go netpoll/connection.go netpoll $GOROOT/src/net/net.go net/net.go net ) @@ -59,6 +61,7 @@ do \ * limitations under the License.\ */\ \ +\ ' $outfile else sed -i '' -e '1i /*\ diff --git a/internal/mocks/utils/sharedticker.go b/internal/mocks/utils/sharedticker.go new file mode 100644 index 0000000000..605c886ee3 --- /dev/null +++ b/internal/mocks/utils/sharedticker.go @@ -0,0 +1,63 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + + +// Code generated by MockGen. DO NOT EDIT. +// Source: ../../pkg/utils/sharedticker.go + +// Package utils is a generated GoMock package. +package utils + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockTickerTask is a mock of TickerTask interface. +type MockTickerTask struct { + ctrl *gomock.Controller + recorder *MockTickerTaskMockRecorder +} + +// MockTickerTaskMockRecorder is the mock recorder for MockTickerTask. +type MockTickerTaskMockRecorder struct { + mock *MockTickerTask +} + +// NewMockTickerTask creates a new mock instance. +func NewMockTickerTask(ctrl *gomock.Controller) *MockTickerTask { + mock := &MockTickerTask{ctrl: ctrl} + mock.recorder = &MockTickerTaskMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTickerTask) EXPECT() *MockTickerTaskMockRecorder { + return m.recorder +} + +// Tick mocks base method. +func (m *MockTickerTask) Tick() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Tick") +} + +// Tick indicates an expected call of Tick. +func (mr *MockTickerTaskMockRecorder) Tick() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tick", reflect.TypeOf((*MockTickerTask)(nil).Tick)) +} diff --git a/internal/server/option.go b/internal/server/option.go index 7b9ea1f360..85085279ed 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -109,6 +109,7 @@ func NewOptions(opts []Option) *Options { Svr: &rpcinfo.EndpointBasicInfo{}, Configs: rpcinfo.NewRPCConfig(), Once: configutil.NewOptionOnce(), + MetaHandlers: []remote.MetaHandler{transmeta.MetainfoServerHandler}, RemoteOpt: newServerRemoteOption(), DebugService: diagnosis.NoopService, ExitSignal: DefaultSysExitSignal, @@ -122,7 +123,6 @@ func NewOptions(opts []Option) *Options { Registry: registry.NoopRegistry, } ApplyOptions(opts, o) - o.MetaHandlers = append(o.MetaHandlers, transmeta.MetainfoServerHandler) rpcinfo.AsMutableRPCConfig(o.Configs).LockConfig(o.LockBits) if o.StatsLevel == nil { level := stats.LevelDisabled diff --git a/internal/server/remote_option.go b/internal/server/remote_option.go index f93ea20ae7..a814df152c 100644 --- a/internal/server/remote_option.go +++ b/internal/server/remote_option.go @@ -25,13 +25,14 @@ import ( "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/trans/detection" "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" ) func newServerRemoteOption() *remote.ServerOption { return &remote.ServerOption{ TransServerFactory: netpoll.NewTransServerFactory(), - SvrHandlerFactory: detection.NewSvrTransHandlerFactory(), + SvrHandlerFactory: detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()), Codec: codec.NewDefaultCodec(), Address: defaultAddress, ExitWaitTime: defaultExitWaitTime, diff --git a/internal/server/remote_option_windows.go b/internal/server/remote_option_windows.go index 7ef1c746b5..79fced6f97 100644 --- a/internal/server/remote_option_windows.go +++ b/internal/server/remote_option_windows.go @@ -23,14 +23,16 @@ package server import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/trans/detection" "github.com/cloudwego/kitex/pkg/remote/trans/gonet" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" ) func newServerRemoteOption() *remote.ServerOption { return &remote.ServerOption{ TransServerFactory: gonet.NewTransServerFactory(), - SvrHandlerFactory: gonet.NewSvrTransHandlerFactory(), + SvrHandlerFactory: detection.NewSvrTransHandlerFactory(gonet.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()), Codec: codec.NewDefaultCodec(), Address: defaultAddress, ExitWaitTime: defaultExitWaitTime, diff --git a/internal/stats/stats_util.go b/internal/stats/stats_util.go index 39dbcab251..0e1d70858f 100644 --- a/internal/stats/stats_util.go +++ b/internal/stats/stats_util.go @@ -25,7 +25,7 @@ import ( // Record records the event to RPCStats. func Record(ctx context.Context, ri rpcinfo.RPCInfo, event stats.Event, err error) { - if ctx == nil { + if ctx == nil || ri.Stats() == nil { return } if err != nil { diff --git a/pkg/circuitbreak/cbsuite.go b/pkg/circuitbreak/cbsuite.go index 76e831cb1e..3b75cefd0f 100644 --- a/pkg/circuitbreak/cbsuite.go +++ b/pkg/circuitbreak/cbsuite.go @@ -46,12 +46,35 @@ func GetDefaultCBConfig() CBConfig { } // CBConfig is policy config of CircuitBreaker. +// DON'T FORGET to update DeepCopy() and Equals() if you add new fields. type CBConfig struct { Enable bool `json:"enable"` ErrRate float64 `json:"err_rate"` MinSample int64 `json:"min_sample"` } +// DeepCopy returns a full copy of CBConfig. +func (c *CBConfig) DeepCopy() *CBConfig { + if c == nil { + return nil + } + return &CBConfig{ + Enable: c.Enable, + ErrRate: c.ErrRate, + MinSample: c.MinSample, + } +} + +func (c *CBConfig) Equals(other *CBConfig) bool { + if c == nil && other == nil { + return true + } + if c == nil || other == nil { + return false + } + return c.Enable == other.Enable && c.ErrRate == other.ErrRate && c.MinSample == other.MinSample +} + // GenServiceCBKeyFunc to generate circuit breaker key through rpcinfo. // You can customize the config key according to your config center. type GenServiceCBKeyFunc func(ri rpcinfo.RPCInfo) string diff --git a/pkg/circuitbreak/cbsuite_test.go b/pkg/circuitbreak/cbsuite_test.go index 22307e0296..09a05d4920 100644 --- a/pkg/circuitbreak/cbsuite_test.go +++ b/pkg/circuitbreak/cbsuite_test.go @@ -20,6 +20,7 @@ import ( "context" "errors" "net" + "reflect" "testing" "time" @@ -430,3 +431,195 @@ func (m mockInst) Weight() int { func (m mockInst) Tag(key string) (value string, exist bool) { return } + +func TestCBConfig_DeepCopy(t *testing.T) { + type fields struct { + c *CBConfig + } + tests := []struct { + name string + fields fields + want *CBConfig + }{ + { + name: "test_nil_copy", + fields: fields{ + c: nil, + }, + want: nil, + }, + { + name: "test_all_copy", + fields: fields{ + c: &CBConfig{ + Enable: true, + ErrRate: 0.1, + MinSample: 10, + }, + }, + want: &CBConfig{ + Enable: true, + ErrRate: 0.1, + MinSample: 10, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.fields.c.DeepCopy(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("DeepCopy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCBConfig_Equals(t *testing.T) { + type fields struct { + c *CBConfig + } + type args struct { + other *CBConfig + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + { + name: "test_nil_equal", + fields: fields{ + c: nil, + }, + args: args{ + other: nil, + }, + want: true, + }, + { + name: "test_nil_not_equal", + fields: fields{ + c: nil, + }, + args: args{ + other: &CBConfig{ + Enable: true, + ErrRate: 0.1, + MinSample: 10, + }, + }, + want: false, + }, + { + name: "test_other_nil_not_equal", + fields: fields{ + c: &CBConfig{ + Enable: true, + ErrRate: 0.1, + MinSample: 10, + }, + }, + args: args{ + other: nil, + }, + want: false, + }, + { + name: "test_all_equal", + fields: fields{ + c: &CBConfig{ + Enable: true, + ErrRate: 0.1, + MinSample: 10, + }, + }, + args: args{ + other: &CBConfig{ + Enable: true, + ErrRate: 0.1, + MinSample: 10, + }, + }, + want: true, + }, + { + name: "test_all_not_equal", + fields: fields{ + c: &CBConfig{ + Enable: true, + ErrRate: 0.1, + MinSample: 10, + }, + }, + args: args{ + other: &CBConfig{ + Enable: false, + ErrRate: 0.2, + MinSample: 20, + }, + }, + want: false, + }, + { + name: "test_enable_equal", + fields: fields{ + c: &CBConfig{ + Enable: true, + ErrRate: 0.1, + MinSample: 10, + }, + }, + args: args{ + other: &CBConfig{ + Enable: true, + ErrRate: 0.2, + MinSample: 20, + }, + }, + want: false, + }, + { + name: "test_err_rate_equal", + fields: fields{ + c: &CBConfig{ + Enable: true, + ErrRate: 0.1, + MinSample: 10, + }, + }, + args: args{ + other: &CBConfig{ + Enable: false, + ErrRate: 0.1, + MinSample: 20, + }, + }, + want: false, + }, + { + name: "test_min_sample_equal", + fields: fields{ + c: &CBConfig{ + Enable: true, + ErrRate: 0.1, + MinSample: 10, + }, + }, + args: args{ + other: &CBConfig{ + Enable: false, + ErrRate: 0.2, + MinSample: 10, + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.fields.c.Equals(tt.args.other); got != tt.want { + t.Errorf("Equals() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/event/queue.go b/pkg/event/queue.go index ad31070301..be38937922 100644 --- a/pkg/event/queue.go +++ b/pkg/event/queue.go @@ -18,7 +18,6 @@ package event import ( "sync" - "sync/atomic" ) const ( @@ -55,28 +54,22 @@ func NewQueue(cap int) Queue { // Push pushes an event to the queue. func (q *queue) Push(e *Event) { - for { - old := atomic.LoadUint32(&q.tail) - new := old + 1 - if new >= uint32(len(q.ring)) { - new = 0 - } - oldV := atomic.LoadUint32(q.tailVersion[old]) - newV := oldV + 1 - if atomic.CompareAndSwapUint32(&q.tail, old, new) && atomic.CompareAndSwapUint32(q.tailVersion[old], oldV, newV) { - q.mu.Lock() - q.ring[old] = e - q.mu.Unlock() - break - } - } + q.mu.Lock() + defer q.mu.Unlock() + + q.ring[q.tail] = e + + newVersion := (*(q.tailVersion[q.tail])) + 1 + q.tailVersion[q.tail] = &newVersion + + q.tail = (q.tail + 1) % uint32(len(q.ring)) } // Dump dumps the previously pushed events out in a reversed order. func (q *queue) Dump() interface{} { - results := make([]*Event, 0, len(q.ring)) q.mu.RLock() defer q.mu.RUnlock() + results := make([]*Event, 0, len(q.ring)) pos := int32(q.tail) for i := 0; i < len(q.ring); i++ { pos-- diff --git a/pkg/fallback/fallback.go b/pkg/fallback/fallback.go new file mode 100644 index 0000000000..177609a288 --- /dev/null +++ b/pkg/fallback/fallback.go @@ -0,0 +1,138 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package fallback ... +package fallback + +import ( + "context" + "errors" + "reflect" + + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/utils" +) + +// ErrorFallback is to build fallback policy for error. +func ErrorFallback(ef Func) *Policy { + return &Policy{fallbackFunc: func(ctx context.Context, req utils.KitexArgs, resp utils.KitexResult, err error) (fbErr error) { + if err == nil { + return err + } + return ef(ctx, req, resp, err) + }} +} + +// TimeoutAndCBFallback is to build fallback policy for rpc timeout and circuit breaker error. +// Kitex will filter the errors, only timeout and circuit breaker can trigger the ErrorFunc to execute. +func TimeoutAndCBFallback(ef Func) *Policy { + return &Policy{fallbackFunc: func(ctx context.Context, req utils.KitexArgs, resp utils.KitexResult, err error) (fbErr error) { + if err == nil { + return err + } + if kerrors.IsTimeoutError(err) || errors.Is(err, kerrors.ErrCircuitBreak) { + return ef(ctx, req, resp, err) + } + return err + }} +} + +// NewFallbackPolicy is to build a fallback policy. +func NewFallbackPolicy(fb Func) *Policy { + return &Policy{ + fallbackFunc: fb, + } +} + +// Policy is the definition for fallback. +// - fallbackFunc is fallback func. +// - reportAsFallback is used to decide whether to report Metric according to the Fallback result. +type Policy struct { + fallbackFunc Func + reportAsFallback bool +} + +func (p *Policy) EnableReportAsFallback() *Policy { + p.reportAsFallback = true + return p +} + +// IsPolicyValid to check if the Fallback policy is valid. +func IsPolicyValid(p *Policy) bool { + return p != nil && p.fallbackFunc != nil +} + +// UnwrapHelper helps to get the real request and response. +// Therefor, the RealReqRespFunc only need to process the real rpc req and resp but not the XXXArgs and XXXResult. +// eg: +// +// `client.WithFallback(fallback.NewFallbackPolicy(fallback.UnwrapHelper(yourRealReqRespFunc)))` +func UnwrapHelper(userFB RealReqRespFunc) Func { + return func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) error { + req, resp := args.GetFirstArgument(), result.GetResult() + fbResp, fbErr := userFB(ctx, req, resp, err) + if fbResp != nil { + result.SetSuccess(fbResp) + } + return fbErr + } +} + +// Func is the definition for fallback func, which can do fallback both for error and resp. +// Notice !! The args and result are not the real rpc req and resp, are respectively XXXArgs and XXXResult of generated code. +// setup eg: client.WithFallback(fallback.NewFallbackPolicy(yourFunc)) +type Func func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) + +// RealReqRespFunc is the definition for fallback func with real rpc req as param, and must return the real rpc resp. +// setup eg: client.WithFallback(fallback.TimeoutAndCBFallback(fallback.UnwrapHelper(yourRealReqRespFunc))) +type RealReqRespFunc func(ctx context.Context, req, resp interface{}, err error) (fbResp interface{}, fbErr error) + +// DoIfNeeded do fallback +func (p *Policy) DoIfNeeded(ctx context.Context, ri rpcinfo.RPCInfo, args, result interface{}, err error) (fbErr error, reportAsFallback bool) { + if p == nil { + return err, false + } + ka, kaOK := args.(utils.KitexArgs) + kr, krOK := result.(utils.KitexResult) + if !kaOK || !krOK { + klog.Warn("KITEX: fallback cannot be supported, the args and result must be KitexArgs and KitexResult") + return err, false + } + err, allowReportAsFB := getBizErrIfExist(ri, err) + reportAsFallback = allowReportAsFB && p.reportAsFallback + + fbErr = p.fallbackFunc(ctx, ka, kr, err) + + if fbErr == nil && reflect.ValueOf(kr.GetResult()).IsNil() { + klog.Warn("KITEX: both fallback resp and error are nil, return original err") + return err, false + } + return fbErr, reportAsFallback +} + +func getBizErrIfExist(ri rpcinfo.RPCInfo, err error) (error, bool) { + if err == nil { + if bizErr := ri.Invocation().BizStatusErr(); bizErr != nil { + // biz error also as error passed to fallback + err = bizErr + } + // if err is nil, reportAsFallback always be false even if user set true + return err, false + } + return err, true +} diff --git a/pkg/fallback/fallback_test.go b/pkg/fallback/fallback_test.go new file mode 100644 index 0000000000..968ba4927c --- /dev/null +++ b/pkg/fallback/fallback_test.go @@ -0,0 +1,178 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fallback + +import ( + "context" + "errors" + "testing" + + mockthrift "github.com/cloudwego/kitex/internal/mocks/thrift" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" + "github.com/cloudwego/kitex/pkg/utils" +) + +func TestNewFallbackPolicy(t *testing.T) { + // case0: policy is nil + var fbP *Policy + result := mockthrift.NewMockTestResult() + fbErr, reportAsFallback := fbP.DoIfNeeded(context.Background(), genRPCInfo(), mockthrift.NewMockTestArgs(), result, errMock) + test.Assert(t, fbErr == errMock) + test.Assert(t, !reportAsFallback) + + // case1: original error is non-nil and set result in fallback + fbP = NewFallbackPolicy(func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) { + _, ok := args.(*mockthrift.MockTestArgs) + test.Assert(t, ok) + result.SetSuccess(&retStr) + return nil + }) + test.Assert(t, !fbP.reportAsFallback) + test.Assert(t, fbP.fallbackFunc != nil) + result = mockthrift.NewMockTestResult() + fbErr, reportAsFallback = fbP.DoIfNeeded(context.Background(), genRPCInfo(), mockthrift.NewMockTestArgs(), result, errMock) + test.Assert(t, *result.GetResult().(*string) == retStr) + test.Assert(t, fbErr == nil, fbErr) + test.Assert(t, !reportAsFallback) + + // case2: enable reportAsFallback + fbP = NewFallbackPolicy(func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) { + _, ok := args.(*mockthrift.MockTestArgs) + test.Assert(t, ok) + result.SetSuccess(&retStr) + return nil + }).EnableReportAsFallback() + test.Assert(t, fbP.reportAsFallback) + _, reportAsFallback = fbP.DoIfNeeded(context.Background(), genRPCInfo(), mockthrift.NewMockTestArgs(), mockthrift.NewMockTestResult(), errMock) + test.Assert(t, reportAsFallback) + + // case3: original error is nil, still can update result + fbP = NewFallbackPolicy(func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) { + _, ok := args.(*mockthrift.MockTestArgs) + test.Assert(t, ok) + result.SetSuccess(&retStr) + return nil + }).EnableReportAsFallback() + test.Assert(t, fbP.reportAsFallback) + result = mockthrift.NewMockTestResult() + fbErr, reportAsFallback = fbP.DoIfNeeded(context.Background(), genRPCInfo(), mockthrift.NewMockTestArgs(), result, nil) + test.Assert(t, *result.GetResult().(*string) == retStr) + test.Assert(t, fbErr == nil) + test.Assert(t, !reportAsFallback) + + // case4: fallback return nil-resp and nil-err, framework will return the original resp and err + fbP = NewFallbackPolicy(func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) { + _, ok := args.(*mockthrift.MockTestArgs) + test.Assert(t, ok) + return + }).EnableReportAsFallback() + test.Assert(t, fbP.reportAsFallback) + fbErr, reportAsFallback = fbP.DoIfNeeded(context.Background(), genRPCInfo(), mockthrift.NewMockTestArgs(), mockthrift.NewMockTestResult(), errMock) + test.Assert(t, fbErr == errMock) + test.Assert(t, !reportAsFallback) + + // case5: WithRealReqResp, original error is non-nil and set result in fallback + fbP = NewFallbackPolicy(UnwrapHelper(func(ctx context.Context, req, resp interface{}, err error) (fbResp interface{}, fbErr error) { + _, ok := req.(*mockthrift.MockReq) + test.Assert(t, ok, req) + return &retStr, nil + })) + test.Assert(t, !fbP.reportAsFallback) + test.Assert(t, fbP.fallbackFunc != nil) + args := mockthrift.NewMockTestArgs() + args.Req = &mockthrift.MockReq{} + result = mockthrift.NewMockTestResult() + fbErr, reportAsFallback = fbP.DoIfNeeded(context.Background(), genRPCInfo(), args, result, errMock) + test.Assert(t, *result.GetResult().(*string) == retStr) + test.Assert(t, fbErr == nil) + test.Assert(t, !reportAsFallback) +} + +func TestErrorFallback(t *testing.T) { + // case1: err is non-nil + fbP := ErrorFallback(func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) { + result.SetSuccess(&retStr) + return nil + }) + result := mockthrift.NewMockTestResult() + fbErr, reportAsFallback := fbP.DoIfNeeded(context.Background(), genRPCInfo(), mockthrift.NewMockTestArgs(), result, errMock) + test.Assert(t, *result.GetResult().(*string) == retStr) + test.Assert(t, fbErr == nil) + test.Assert(t, !reportAsFallback) + + // case 2: err is nil, then the fallback func won't be executed + fbExecuted := false + fbP = ErrorFallback(UnwrapHelper(func(ctx context.Context, req, resp interface{}, err error) (fbResp interface{}, fbErr error) { + fbExecuted = true + return &retStr, nil + })).EnableReportAsFallback() + result = mockthrift.NewMockTestResult() + fbErr, reportAsFallback = fbP.DoIfNeeded(context.Background(), genRPCInfo(), mockthrift.NewMockTestArgs(), result, nil) + test.Assert(t, fbErr == nil) + test.Assert(t, !reportAsFallback) + test.Assert(t, !fbExecuted) +} + +func TestTimeoutAndCBFallback(t *testing.T) { + // case1: rpc timeout will do fallback + fbP := TimeoutAndCBFallback(UnwrapHelper(func(ctx context.Context, req, resp interface{}, err error) (fbResp interface{}, fbErr error) { + _, ok := req.(*mockthrift.MockReq) + test.Assert(t, ok) + return &retStr, nil + })) + result := mockthrift.NewMockTestResult() + fbErr, reportAsFallback := fbP.DoIfNeeded(context.Background(), genRPCInfo(), mockthrift.NewMockTestArgs(), result, kerrors.ErrRPCTimeout.WithCause(errMock)) + test.Assert(t, *result.GetResult().(*string) == retStr) + test.Assert(t, fbErr == nil) + test.Assert(t, !reportAsFallback) + + // case2: circuit breaker error will do fallback + fbP = TimeoutAndCBFallback(func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) { + result.SetSuccess(&retStr) + return nil + }).EnableReportAsFallback() + result = mockthrift.NewMockTestResult() + fbErr, reportAsFallback = fbP.DoIfNeeded(context.Background(), genRPCInfo(), mockthrift.NewMockTestArgs(), result, kerrors.ErrCircuitBreak.WithCause(errMock)) + test.Assert(t, *result.GetResult().(*string) == retStr) + test.Assert(t, fbErr == nil) + test.Assert(t, reportAsFallback) + + // case3: err is non-nil, but not rpc timeout or circuit breaker + fbP = TimeoutAndCBFallback(func(ctx context.Context, args utils.KitexArgs, result utils.KitexResult, err error) (fbErr error) { + result.SetSuccess(&retStr) + return nil + }) + result = mockthrift.NewMockTestResult() + fbErr, reportAsFallback = fbP.DoIfNeeded(context.Background(), genRPCInfo(), mockthrift.NewMockTestArgs(), result, errMock) + test.Assert(t, fbErr == errMock) + test.Assert(t, !reportAsFallback) +} + +var ( + method = "test" + errMock = errors.New("mock") + retStr = "success" +) + +func genRPCInfo() rpcinfo.RPCInfo { + to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{Method: method}, method).ImmutableView() + ri := rpcinfo.NewRPCInfo(to, to, rpcinfo.NewInvocation("", method), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + return ri +} diff --git a/pkg/generic/descriptor/router.go b/pkg/generic/descriptor/router.go index 66821b971a..c1a37084b3 100644 --- a/pkg/generic/descriptor/router.go +++ b/pkg/generic/descriptor/router.go @@ -99,7 +99,7 @@ func (r *router) Lookup(req *HTTPRequest) (*FunctionDescriptor, error) { if !ok { return nil, fmt.Errorf("function lookup failed, no root with method=%s", req.Method) } - fn, ps, _ := root.getValue(req.Path, r.getParams) + fn, ps, _ := root.getValue(req.Path, r.getParams, false) if fn == nil { r.putParams(ps) return nil, fmt.Errorf("function lookup failed, path=%s", req.Path) diff --git a/pkg/generic/descriptor/tree.go b/pkg/generic/descriptor/tree.go index 23d1fa0213..b1186fd203 100644 --- a/pkg/generic/descriptor/tree.go +++ b/pkg/generic/descriptor/tree.go @@ -10,50 +10,11 @@ package descriptor import ( + "fmt" + "net/url" "strings" ) -func min(a, b int) int { - if a <= b { - return a - } - return b -} - -func longestCommonPrefix(a, b string) int { - i := 0 - max := min(len(a), len(b)) - for i < max && a[i] == b[i] { - i++ - } - return i -} - -// Search for a wildcard segment and check the name for invalid characters. -// Returns -1 as index, if no wildcard was found. -func findWildcard(path string) (wildcard string, i int, valid bool) { - // Find start - for start, c := range []byte(path) { - // A wildcard starts with ':' (param) or '*' (catch-all) - if c != ':' && c != '*' { - continue - } - - // Find end and check for invalid characters - valid = true - for end, c := range []byte(path[start+1:]) { - switch c { - case '/': - return path[start : start+1+end], start, valid - case ':', '*': - valid = false - } - } - return path[start:], start, valid - } - return "", -1, false -} - func countParams(path string) uint16 { var n uint for i := range []byte(path) { @@ -69,249 +30,227 @@ type nodeType uint8 const ( static nodeType = iota // default - root param catchAll + paramLabel = byte(':') + anyLabel = byte('*') + slash = "/" + nilString = "" ) -type node struct { - path string - indices string - wildChild bool - nType nodeType - priority uint32 - children []*node - function *FunctionDescriptor -} - -// Increments priority of the given child and reorders if necessary -func (n *node) incrementChildPrio(pos int) int { - cs := n.children - cs[pos].priority++ - prio := cs[pos].priority - - // Adjust position (move to front) - newPos := pos - for ; newPos > 0 && cs[newPos-1].priority < prio; newPos-- { - // Swap node positions - cs[newPos-1], cs[newPos] = cs[newPos], cs[newPos-1] +type ( + node struct { + nType nodeType + label byte + prefix string + parent *node + children children + // original path + ppath string + // param names + pnames []string + function *FunctionDescriptor + paramChild *node + anyChild *node + // isLeaf indicates that node does not have child routes + isLeaf bool } + children []*node +) - // Build new index char string - if newPos != pos { - n.indices = n.indices[:newPos] + // Unchanged prefix, might be empty - n.indices[pos:pos+1] + // The index char we move - n.indices[newPos:pos] + n.indices[pos+1:] // Rest without char at 'pos' +func checkPathValid(path string) { + if path == nilString { + panic("empty path") + } + if path[0] != '/' { + panic("path must begin with '/'") + } + for i, c := range []byte(path) { + switch c { + case ':': + if (i < len(path)-1 && path[i+1] == '/') || i == (len(path)-1) { + panic("wildcards must be named with a non-empty name in path '" + path + "'") + } + i++ + for ; i < len(path) && path[i] != '/'; i++ { + if path[i] == ':' || path[i] == '*' { + panic("only one wildcard per path segment is allowed, find multi in path '" + path + "'") + } + } + case '*': + if i == len(path)-1 { + panic("wildcards must be named with a non-empty name in path '" + path + "'") + } + if i > 0 && path[i-1] != '/' { + panic(" no / before wildcards in path " + path) + } + for ; i < len(path); i++ { + if path[i] == '/' { + panic("catch-all routes are only allowed at the end of the path in path '" + path + "'") + } + } + } } - - return newPos } // addRoute adds a node with the given function to the path. // Not concurrency-safe! func (n *node) addRoute(path string, function *FunctionDescriptor) { - fullPath := path - n.priority++ - - // Empty tree - if n.path == "" && n.indices == "" { - n.insertChild(path, fullPath, function) - n.nType = root - return - } + checkPathValid(path) -walk: - for { - // Find the longest common prefix. - // This also implies that the common prefix contains no ':' or '*' - // since the existing key can't contain those chars. - i := longestCommonPrefix(path, n.path) - - // Split edge - if i < len(n.path) { - child := node{ - path: n.path[i:], - wildChild: n.wildChild, - nType: static, - indices: n.indices, - children: n.children, - function: n.function, - priority: n.priority - 1, - } - - n.children = []*node{&child} - // []byte for proper unicode char conversion, see #65 - n.indices = string([]byte{n.path[i]}) - n.path = path[:i] - n.function = nil - n.wildChild = false - } + var ( + pnames []string // Param names + ppath = path // Pristine path + ) - // Make new node a child of this node - if i < len(path) { - path = path[i:] - - if n.wildChild { - n = n.children[0] - n.priority++ - - // Check if the wildcard matches - if len(path) >= len(n.path) && n.path == path[:len(n.path)] && - // Adding a child to a catchAll is not possible - n.nType != catchAll && - // Check for longer wildcard, e.g. :name and :names - (len(n.path) >= len(path) || path[len(n.path)] == '/') { - continue walk - } else { - // Wildcard conflict - pathSeg := path - if n.nType != catchAll { - pathSeg = strings.SplitN(pathSeg, "/", 2)[0] - } - prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path - panic("'" + pathSeg + - "' in new path '" + fullPath + - "' conflicts with existing wildcard '" + n.path + - "' in existing prefix '" + prefix + - "'") - } - } + if function == nil { + panic(fmt.Sprintf("adding route without handler function: %v", path)) + } - idxc := path[0] + // Add the front static route part of a non-static route + for i, lcpIndex := 0, len(path); i < lcpIndex; i++ { + // param route + if path[i] == paramLabel { + j := i + 1 - // '/' after param - if n.nType == param && idxc == '/' && len(n.children) == 1 { - n = n.children[0] - n.priority++ - continue walk + n.insert(path[:i], nil, static, nilString, nil) + for ; i < lcpIndex && path[i] != '/'; i++ { } - // Check if a child with the next path byte exists - for i, c := range []byte(n.indices) { - if c == idxc { - i = n.incrementChildPrio(i) - n = n.children[i] - continue walk - } - } + pnames = append(pnames, path[j:i]) + path = path[:j] + path[i:] + i, lcpIndex = j, len(path) - // Otherwise insert it - if idxc != ':' && idxc != '*' { - // []byte for proper unicode char conversion, see #65 - n.indices += string([]byte{idxc}) - child := &node{} - n.children = append(n.children, child) - n.incrementChildPrio(len(n.indices) - 1) - n = child + if i == lcpIndex { + // path node is last fragment of route path. ie. `/users/:id` + n.insert(path[:i], function, param, ppath, pnames) + return + } else { + n.insert(path[:i], nil, param, nilString, pnames) } - n.insertChild(path, fullPath, function) + } else if path[i] == anyLabel { + n.insert(path[:i], nil, static, nilString, nil) + pnames = append(pnames, path[i+1:]) + n.insert(path[:i+1], function, catchAll, ppath, pnames) return } - - // Otherwise add function to current node - if n.function != nil { - panic("a function is already registered for path '" + fullPath + "'") - } - n.function = function - return } + n.insert(path, function, static, ppath, pnames) } -func (n *node) insertChild(path, fullPath string, function *FunctionDescriptor) { +func (n *node) insert(path string, function *FunctionDescriptor, t nodeType, ppath string, pnames []string) { + currentNode := n + search := path + for { - // Find prefix until first wildcard - wildcard, i, valid := findWildcard(path) - if i < 0 { // No wildcard found - break - } + searchLen := len(search) + prefixLen := len(currentNode.prefix) + lcpLen := 0 - // The wildcard name must not contain ':' and '*' - if !valid { - panic("only one wildcard per path segment is allowed, has: '" + - wildcard + "' in path '" + fullPath + "'") + max := prefixLen + if searchLen < max { + max = searchLen } - - // Check if the wildcard has a name - if len(wildcard) < 2 { - panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") + for ; lcpLen < max && search[lcpLen] == currentNode.prefix[lcpLen]; lcpLen++ { } - // param - if wildcard[0] == ':' { - if i > 0 { - // Insert prefix before the current wildcard - n.path = path[:i] - path = path[i:] + if lcpLen == 0 { + currentNode.label = search[0] + currentNode.prefix = search + if function != nil { + currentNode.nType = t + currentNode.function = function + currentNode.ppath = ppath + currentNode.pnames = pnames + } + currentNode.isLeaf = currentNode.children == nil && currentNode.paramChild == nil && currentNode.anyChild == nil + } else if lcpLen < prefixLen { + // Split node + n := newNode( + currentNode.nType, + currentNode.prefix[lcpLen:], + currentNode, + currentNode.children, + currentNode.function, + currentNode.ppath, + currentNode.pnames, + currentNode.paramChild, + currentNode.anyChild, + ) + // Update parent path for all children to new node + for _, child := range currentNode.children { + child.parent = n + } + if currentNode.paramChild != nil { + currentNode.paramChild.parent = n + } + if currentNode.anyChild != nil { + currentNode.anyChild.parent = n } - n.wildChild = true - child := &node{ - nType: param, - path: wildcard, + // Reset parent node + currentNode.nType = static + currentNode.label = currentNode.prefix[0] + currentNode.prefix = currentNode.prefix[:lcpLen] + currentNode.children = nil + currentNode.function = nil + currentNode.ppath = nilString + currentNode.pnames = nil + currentNode.paramChild = nil + currentNode.anyChild = nil + currentNode.isLeaf = false + + // Only Static children could reach here + currentNode.children = append(currentNode.children, n) + + if lcpLen == searchLen { + // At parent node + currentNode.nType = t + currentNode.function = function + currentNode.ppath = ppath + currentNode.pnames = pnames + } else { + // Create child node + n = newNode(t, search[lcpLen:], currentNode, nil, function, ppath, pnames, nil, nil) + // Only Static children could reach here + currentNode.children = append(currentNode.children, n) } - n.children = []*node{child} - n = child - n.priority++ - - // If the path doesn't end with the wildcard, then there - // will be another non-wildcard subpath starting with '/' - if len(wildcard) < len(path) { - path = path[len(wildcard):] - child := &node{ - priority: 1, - } - n.children = []*node{child} - n = child + currentNode.isLeaf = currentNode.children == nil && currentNode.paramChild == nil && currentNode.anyChild == nil + } else if lcpLen < searchLen { + search = search[lcpLen:] + c := currentNode.findChildWithLabel(search[0]) + if c != nil { + // Go deeper + currentNode = c continue } + // Create child node + n := newNode(t, search, currentNode, nil, function, ppath, pnames, nil, nil) + switch t { + case static: + currentNode.children = append(currentNode.children, n) + case param: + currentNode.paramChild = n + case catchAll: + currentNode.anyChild = n + } + currentNode.isLeaf = currentNode.children == nil && currentNode.paramChild == nil && currentNode.anyChild == nil + } else { + // Node already exists + if currentNode.function != nil && function != nil { + panic("handlers are already registered for path '" + ppath + "'") + } - // Otherwise we're done. Insert the function in the new leaf - n.function = function - return - } - - // catchAll - if i+len(wildcard) != len(path) { - panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'") - } - - if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { - panic("catch-all conflicts with existing function for the path segment root in path '" + fullPath + "'") - } - - // Currently fixed width 1 for '/' - i-- - if path[i] != '/' { - panic("no / before catch-all in path '" + fullPath + "'") - } - - n.path = path[:i] - - // First node: catchAll node with empty path - child := &node{ - wildChild: true, - nType: catchAll, - } - n.children = []*node{child} - n.indices = string('/') - n = child - n.priority++ - - // Second node: node holding the variable - child = &node{ - path: path[i:], - nType: catchAll, - function: function, - priority: 1, + if function != nil { + currentNode.function = function + currentNode.ppath = ppath + if len(currentNode.pnames) == 0 { + currentNode.pnames = pnames + } + } } - n.children = []*node{child} - return } - - // If no wildcard was found, simply insert the path and function - n.path = path - n.function = function } // Returns the function registered with the given path (key). The values of @@ -319,136 +258,206 @@ func (n *node) insertChild(path, fullPath string, function *FunctionDescriptor) // If no function can be found, a TSR (trailing slash redirect) recommendation is // made if a function exists with an extra (without the) trailing slash for the // given path. -func (n *node) getValue(path string, params func() *Params) (function *FunctionDescriptor, ps *Params, tsr bool) { -walk: // Outer loop for walking the tree +func (n *node) getValue(path string, params func() *Params, unescape bool) (function *FunctionDescriptor, ps *Params, tsr bool) { + var ( + cn = n // current node + search = path // current path + searchIndex = 0 + paramIndex int + ) + + backtrackToNextNodeType := func(fromNodeType nodeType) (nextNodeType nodeType, valid bool) { + previous := cn + cn = previous.parent + valid = cn != nil + + // Next node type by priority + if previous.nType == catchAll { + nextNodeType = static + } else { + nextNodeType = previous.nType + 1 + } + + if fromNodeType == static { + // when backtracking is done from static type block we did not change search so nothing to restore + return + } + + // restore search to value it was before we move to current node we are backtracking from. + if previous.nType == static { + searchIndex -= len(previous.prefix) + } else { + paramIndex-- + // for param/any node.prefix value is always `:`/`*` so we cannot deduce searchIndex from that and must use pValue + // for that index as it would also contain part of path we cut off before moving into node we are backtracking from + searchIndex -= len(ps.params[paramIndex].Value) + ps.params = ps.params[:paramIndex] + } + search = path[searchIndex:] + return + } + + // search order: static > param > any for { - prefix := n.path - if len(path) > len(prefix) { - if path[:len(prefix)] == prefix { - path = path[len(prefix):] - - // If this node does not have a wildcard (param or catchAll) - // child, we can just look up the next child node and continue - // to walk down the tree - if !n.wildChild { - idxc := path[0] - for i, c := range []byte(n.indices) { - if c == idxc { - n = n.children[i] - continue walk - } - } - - // Nothing found. - // We can recommend to redirect to the same URL without a - // trailing slash if a leaf exists for that path. - tsr = (path == "/" && n.function != nil) - return + if cn.nType == static { + if len(search) >= len(cn.prefix) && cn.prefix == search[:len(cn.prefix)] { + // Continue search + search = search[len(cn.prefix):] + searchIndex = searchIndex + len(cn.prefix) + } else { + // not equal + if (len(cn.prefix) == len(search)+1) && + (cn.prefix[len(search)]) == '/' && cn.prefix[:len(search)] == search && (cn.function != nil || cn.anyChild != nil) { + tsr = true } - - // Handle wildcard child - n = n.children[0] - switch n.nType { - case param: - // Find param end (either '/' or path end) - end := 0 - for end < len(path) && path[end] != '/' { - end++ - } - - // Save param value - if params != nil { - if ps == nil { - ps = params() - } - // Expand slice within preallocated capacity - i := len(ps.params) - ps.params = ps.params[:i+1] - ps.params[i] = Param{ - Key: n.path[1:], - Value: path[:end], - } - } - - // We need to go deeper! - if end < len(path) { - if len(n.children) > 0 { - path = path[end:] - n = n.children[0] - continue walk - } - - // ... but we can't - tsr = (len(path) == end+1) - return - } - - if function = n.function; function != nil { - return - } else if len(n.children) == 1 { - // No function found. Check if a function for this path + a - // trailing slash exists for TSR recommendation - n = n.children[0] - tsr = (n.path == "/" && n.function != nil) || (n.path == "" && n.indices == "/") - } - - return - - case catchAll: - // Save param value - if params != nil { - if ps == nil { - ps = params() - } - // Expand slice within preallocated capacity - i := len(ps.params) - ps.params = ps.params[:i+1] - ps.params[i] = Param{ - Key: n.path[2:], - Value: path, - } - } - - function = n.function - return - - default: - panic("invalid node type") + // No matching prefix, let's backtrack to the first possible alternative node of the decision path + nk, ok := backtrackToNextNodeType(static) + if !ok { + return // No other possibilities on the decision path + } else if nk == param { + goto Param + } else { + // Not found (this should never be possible for static node we are looking currently) + break } } - } else if path == prefix { - // We should have reached the node containing the function. - // Check if this node has a function registered. - if function = n.function; function != nil { - return + } + if search == nilString && cn.function != nil { + function = cn.function + break + } + + // Static node + if search != nilString { + // If it can execute that logic, there is handler registered on the current node and search is `/`. + if search == "/" && cn.function != nil { + tsr = true + } + if child := cn.findChild(search[0]); child != nil { + cn = child + continue } + } - // If there is no function for this route, but this route has a - // wildcard child, there must be a function for this path with an - // additional trailing slash - if path == "/" && n.wildChild && n.nType != root { + if search == nilString { + if cd := cn.findChild('/'); cd != nil && (cd.function != nil || cd.anyChild != nil) { tsr = true - return } + } - // No function found. Check if a function for this path + a - // trailing slash exists for trailing slash recommendation - for i, c := range []byte(n.indices) { - if c == '/' { - n = n.children[i] - tsr = (len(n.path) == 1 && n.function != nil) || - (n.nType == catchAll && n.children[0].function != nil) - return + Param: + // Param node + if child := cn.paramChild; search != nilString && child != nil { + cn = child + i := strings.Index(search, slash) + if i == -1 { + i = len(search) + } + if ps == nil { + ps = params() + } + val := search[:i] + if unescape { + if v, err := url.QueryUnescape(val); err == nil { + val = v } } - return + ps.params = ps.params[:paramIndex+1] + ps.params[paramIndex].Value = val + paramIndex++ + search = search[i:] + searchIndex = searchIndex + i + if search == nilString { + if cd := cn.findChild('/'); cd != nil && (cd.function != nil || cd.anyChild != nil) { + tsr = true + } + } + continue + } + Any: + // Any node + if child := cn.anyChild; child != nil { + // If any node is found, use remaining path for paramValues + cn = child + if ps == nil { + ps = params() + } + index := len(cn.pnames) - 1 + val := search + if unescape { + if v, err := url.QueryUnescape(val); err == nil { + val = v + } + } + ps.params = ps.params[:paramIndex+1] + ps.params[index].Value = val + // update indexes/search in case we need to backtrack when no handler match is found + paramIndex++ + searchIndex += len(search) + search = nilString + function = cn.function + break } - // Nothing found. We can recommend to redirect to the same URL with an - // extra trailing slash if a leaf exists for that path - tsr = (path == "/") || - (len(prefix) == len(path)+1 && prefix[len(path)] == '/' && - path == prefix[:len(prefix)-1] && n.function != nil) - return + // Let's backtrack to the first possible alternative node of the decision path + nk, ok := backtrackToNextNodeType(catchAll) + if !ok { + break // No other possibilities on the decision path + } else if nk == param { + goto Param + } else if nk == catchAll { + goto Any + } else { + // Not found + break + } + } + + if cn != nil { + for i, name := range cn.pnames { + ps.params[i].Key = name + } + } + + return +} + +func (n *node) findChild(l byte) *node { + for _, c := range n.children { + if c.label == l { + return c + } + } + return nil +} + +func (n *node) findChildWithLabel(l byte) *node { + for _, c := range n.children { + if c.label == l { + return c + } + } + if l == paramLabel { + return n.paramChild + } + if l == anyLabel { + return n.anyChild + } + return nil +} + +func newNode(t nodeType, pre string, p *node, child children, f *FunctionDescriptor, ppath string, pnames []string, paramChildren, anyChildren *node) *node { + return &node{ + nType: t, + label: pre[0], + prefix: pre, + parent: p, + children: child, + ppath: ppath, + pnames: pnames, + function: f, + paramChild: paramChildren, + anyChild: anyChildren, + isLeaf: child == nil && paramChildren == nil && anyChildren == nil, } } diff --git a/pkg/generic/descriptor/tree_test.go b/pkg/generic/descriptor/tree_test.go index 5b20ab2670..5566df84df 100644 --- a/pkg/generic/descriptor/tree_test.go +++ b/pkg/generic/descriptor/tree_test.go @@ -10,22 +10,11 @@ package descriptor import ( - "fmt" "reflect" - "regexp" "strings" "testing" ) -// func printChildren(n *node, prefix string) { -// fmt.Printf(" %02d %s%s[%d] %v %t %d \r\n", n.priority, prefix, n.path, len(n.children), n.handle, n.wildChild, n.nType) -// for l := len(n.path); l > 0; l-- { -// prefix += " " -// } -// for _, child := range n.children { -// printChildren(child, prefix) -// } -// } func fakeHandler(val string) *FunctionDescriptor { return &FunctionDescriptor{Name: val} } @@ -43,9 +32,13 @@ func getParams() *Params { } } -func checkRequests(t *testing.T, tree *node, requests testRequests) { +func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { + unescape := false + if len(unescapes) >= 1 { + unescape = unescapes[0] + } for _, request := range requests { - handler, psp, _ := tree.getValue(request.path, getParams) + handler, psp, _ := tree.getValue(request.path, getParams, unescape) switch { case handler == nil: @@ -71,32 +64,43 @@ func checkRequests(t *testing.T, tree *node, requests testRequests) { } } -func checkPriorities(t *testing.T, n *node) uint32 { - var prio uint32 - for i := range n.children { - prio += checkPriorities(t, n.children[i]) +func TestCountParams(t *testing.T) { + if countParams("/path/:param1/static/*catch-all") != 2 { + t.Fail() } - - if n.function != nil { - prio++ + if countParams(strings.Repeat("/:param", 256)) != 256 { + t.Fail() } +} - if n.priority != prio { - t.Errorf( - "priority mismatch for node '%s': is %d, should be %d", - n.path, n.priority, prio, - ) - } +func TestNoFunction(t *testing.T) { + tree := &node{} - return prio + route := "/hi" + recv := catchPanic(func() { + tree.addRoute(route, nil) + }) + if recv == nil { + t.Fatalf("no panic while inserting route with empty function '%s", route) + } } -func TestCountParams(t *testing.T) { - if countParams("/path/:param1/static/*catch-all") != 2 { - t.Fail() +func TestEmptyPath(t *testing.T) { + tree := &node{} + + routes := [...]string{ + "", + "user", + ":user", + "*user", } - if countParams(strings.Repeat("/:param", 256)) != 256 { - t.Fail() + for _, route := range routes { + recv := catchPanic(func() { + tree.addRoute(route, nil) + }) + if recv == nil { + t.Fatalf("no panic while inserting route with empty path '%s", route) + } } } @@ -120,9 +124,9 @@ func TestTreeAddAndGet(t *testing.T) { tree.addRoute(route, fakeHandler(route)) } - // printChildren(tree, "") - checkRequests(t, tree, testRequests{ + {"", true, "", nil}, + {"a", true, "", nil}, {"/a", false, "/a", nil}, {"/", true, "", nil}, {"/hi", false, "/hi", nil}, @@ -135,8 +139,6 @@ func TestTreeAddAndGet(t *testing.T) { {"/α", false, "/α", nil}, {"/β", false, "/β", nil}, }) - - checkPriorities(t, tree) } func TestTreeWildcard(t *testing.T) { @@ -146,6 +148,7 @@ func TestTreeWildcard(t *testing.T) { "/", "/cmd/:tool/:sub", "/cmd/:tool/", + "/cmd/xxx/", "/src/*filepath", "/search/", "/search/:query", @@ -157,31 +160,68 @@ func TestTreeWildcard(t *testing.T) { "/doc/go1.html", "/info/:user/public", "/info/:user/project/:project", + "/a/b/:c", + "/a/:b/c/d", + "/a/*b", } for _, route := range routes { tree.addRoute(route, fakeHandler(route)) } - // printChildren(tree, "") - checkRequests(t, tree, testRequests{ {"/", false, "/", nil}, {"/cmd/test/", false, "/cmd/:tool/", &Params{params: []Param{{"tool", "test"}}}}, - {"/cmd/test", true, "", &Params{params: []Param{{"tool", "test"}}}}, + {"/cmd/test", true, "", &Params{params: []Param{}}}, {"/cmd/test/3", false, "/cmd/:tool/:sub", &Params{params: []Param{{"tool", "test"}, {"sub", "3"}}}}, - {"/src/", false, "/src/*filepath", &Params{params: []Param{{"filepath", "/"}}}}, - {"/src/some/file.png", false, "/src/*filepath", &Params{params: []Param{{"filepath", "/some/file.png"}}}}, + {"/src/", false, "/src/*filepath", &Params{params: []Param{{"filepath", ""}}}}, + {"/src/some/file.png", false, "/src/*filepath", &Params{params: []Param{{"filepath", "some/file.png"}}}}, {"/search/", false, "/search/", nil}, {"/search/someth!ng+in+ünìcodé", false, "/search/:query", &Params{params: []Param{{"query", "someth!ng+in+ünìcodé"}}}}, - {"/search/someth!ng+in+ünìcodé/", true, "", &Params{params: []Param{{"query", "someth!ng+in+ünìcodé"}}}}, + {"/search/someth!ng+in+ünìcodé/", true, "", &Params{params: []Param{}}}, {"/user_gopher", false, "/user_:name", &Params{params: []Param{{"name", "gopher"}}}}, {"/user_gopher/about", false, "/user_:name/about", &Params{params: []Param{{"name", "gopher"}}}}, - {"/files/js/inc/framework.js", false, "/files/:dir/*filepath", &Params{params: []Param{{"dir", "js"}, {"filepath", "/inc/framework.js"}}}}, + {"/files/js/inc/framework.js", false, "/files/:dir/*filepath", &Params{params: []Param{{"dir", "js"}, {"filepath", "inc/framework.js"}}}}, {"/info/gordon/public", false, "/info/:user/public", &Params{params: []Param{{"user", "gordon"}}}}, {"/info/gordon/project/go", false, "/info/:user/project/:project", &Params{params: []Param{{"user", "gordon"}, {"project", "go"}}}}, + {"/a/b/c", false, "/a/b/:c", &Params{params: []Param{{Key: "c", Value: "c"}}}}, + {"/a/b/c/d", false, "/a/:b/c/d", &Params{params: []Param{{Key: "b", Value: "b"}}}}, + {"/a/b", false, "/a/*b", &Params{params: []Param{{Key: "b", Value: "b"}}}}, }) +} - checkPriorities(t, tree) +func TestUnescapeParameters(t *testing.T) { + tree := &node{} + + routes := [...]string{ + "/", + "/cmd/:tool/:sub", + "/cmd/:tool/", + "/src/*filepath", + "/search/:query", + "/files/:dir/*filepath", + "/info/:user/project/:project", + "/info/:user", + } + for _, route := range routes { + tree.addRoute(route, fakeHandler(route)) + } + + unescape := true + checkRequests(t, tree, testRequests{ + {"/", false, "/", nil}, + {"/cmd/test/", false, "/cmd/:tool/", &Params{params: []Param{{"tool", "test"}}}}, + {"/cmd/test", true, "", &Params{params: []Param{}}}, + {"/src/some/file.png", false, "/src/*filepath", &Params{params: []Param{{"filepath", "some/file.png"}}}}, + {"/src/some/file+test.png", false, "/src/*filepath", &Params{params: []Param{{"filepath", "some/file test.png"}}}}, + {"/src/some/file++++%%%%test.png", false, "/src/*filepath", &Params{params: []Param{{"filepath", "some/file++++%%%%test.png"}}}}, + {"/src/some/file%2Ftest.png", false, "/src/*filepath", &Params{params: []Param{{"filepath", "some/file/test.png"}}}}, + {"/search/someth!ng+in+ünìcodé", false, "/search/:query", &Params{params: []Param{{"query", "someth!ng in ünìcodé"}}}}, + {"/info/gordon/project/go", false, "/info/:user/project/:project", &Params{params: []Param{{"user", "gordon"}, {"project", "go"}}}}, + {"/info/slash%2Fgordon", false, "/info/:user", &Params{params: []Param{{"user", "slash/gordon"}}}}, + {"/info/slash%2Fgordon/project/Project%20%231", false, "/info/:user/project/:project", &Params{params: []Param{{"user", "slash/gordon"}, {"project", "Project #1"}}}}, + {"/info/slash%%%%", false, "/info/:user", &Params{params: []Param{{"user", "slash%%%%"}}}}, + {"/info/slash%%%%2Fgordon/project/Project%%%%20%231", false, "/info/:user/project/:project", &Params{params: []Param{{"user", "slash%%%%2Fgordon"}, {"project", "Project%%%%20%231"}}}}, + }, unescape) } func catchPanic(testFunc func()) (recv interface{}) { @@ -204,7 +244,7 @@ func testRoutes(t *testing.T, routes []testRoute) { for i := range routes { route := routes[i] recv := catchPanic(func() { - tree.addRoute(route.path, nil) + tree.addRoute(route.path, fakeHandler(route.path)) }) if route.conflict { @@ -215,27 +255,25 @@ func testRoutes(t *testing.T, routes []testRoute) { t.Errorf("unexpected panic for route '%s': %v", route.path, recv) } } - - // printChildren(tree, "") } func TestTreeWildcardConflict(t *testing.T) { routes := []testRoute{ {"/cmd/:tool/:sub", false}, - {"/cmd/vet", true}, + {"/cmd/vet", false}, {"/src/*filepath", false}, {"/src/*filepathx", true}, - {"/src/", true}, + {"/src/", false}, {"/src1/", false}, - {"/src1/*filepath", true}, + {"/src1/*filepath", false}, {"/src2*filepath", true}, {"/search/:query", false}, - {"/search/invalid", true}, - {"/user_:name", false}, - {"/user_x", true}, + {"/search/invalid", false}, {"/user_:name", false}, + {"/user_x", false}, + {"/user_:name", true}, {"/id:id", false}, - {"/id/:id", true}, + {"/id/:id", false}, } testRoutes(t, routes) } @@ -245,13 +283,13 @@ func TestTreeChildConflict(t *testing.T) { {"/cmd/vet", false}, {"/cmd/:tool/:sub", false}, {"/src/AUTHORS", false}, - {"/src/*filepath", true}, + {"/src/*filepath", false}, {"/user_x", false}, {"/user_:name", false}, {"/id/:id", false}, {"/id:id", false}, {"/:id", false}, - {"/*filepath", true}, + {"/*filepath", false}, } testRoutes(t, routes) } @@ -277,19 +315,17 @@ func TestTreeDuplicatePath(t *testing.T) { // Add again recv = catchPanic(func() { - tree.addRoute(route, nil) + tree.addRoute(route, fakeHandler(route)) }) if recv == nil { t.Fatalf("no panic while inserting duplicate route '%s", route) } } - // printChildren(tree, "") - checkRequests(t, tree, testRequests{ {"/", false, "/", nil}, {"/doc/", false, "/doc/", nil}, - {"/src/some/file.png", false, "/src/*filepath", &Params{params: []Param{{"filepath", "/some/file.png"}}}}, + {"/src/some/file.png", false, "/src/*filepath", &Params{params: []Param{{"filepath", "some/file.png"}}}}, {"/search/someth!ng+in+ünìcodé", false, "/search/:query", &Params{params: []Param{{"query", "someth!ng+in+ünìcodé"}}}}, {"/user_gopher", false, "/user_:name", &Params{params: []Param{{"name", "gopher"}}}}, }) @@ -326,14 +362,6 @@ func TestTreeCatchAllConflict(t *testing.T) { testRoutes(t, routes) } -func TestTreeCatchAllConflictRoot(t *testing.T) { - routes := []testRoute{ - {"/", false}, - {"/*filepath", true}, - } - testRoutes(t, routes) -} - func TestTreeCatchMaxParams(t *testing.T) { tree := &node{} route := "/cmd/*filepath" @@ -390,7 +418,14 @@ func TestTreeTrailingSlashRedirect(t *testing.T) { "/no/a", "/no/b", "/api/hello/:name", - "/vendor/:x/*y", + "/user/:name/*id", + "/resource", + "/r/*id", + "/book/biz/:name", + "/book/biz/abc", + "/book/biz/abc/bar", + "/book/:page/:name", + "/book/hello/:name/biz/", } for i := range routes { route := routes[i] @@ -402,8 +437,6 @@ func TestTreeTrailingSlashRedirect(t *testing.T) { } } - // printChildren(tree, "") - tsrRoutes := [...]string{ "/hi/", "/b", @@ -419,10 +452,14 @@ func TestTreeTrailingSlashRedirect(t *testing.T) { "/admin/config/", "/admin/config/permissions/", "/doc/", - "/vendor/x", + "/user/name", + "/r", + "/book/hello/a/biz", + "/book/biz/foo/", + "/book/biz/abc/bar/", } for _, route := range tsrRoutes { - handler, _, tsr := tree.getValue(route, nil) + handler, _, tsr := tree.getValue(route, getParams, false) if handler != nil { t.Fatalf("non-nil handler for TSR route '%s", route) } else if !tsr { @@ -437,9 +474,57 @@ func TestTreeTrailingSlashRedirect(t *testing.T) { "/_", "/_/", "/api/world/abc", + "/book", + "/book/", + "/book/hello/a/abc", + "/book/biz/abc/biz", } for _, route := range noTsrRoutes { - handler, _, tsr := tree.getValue(route, nil) + handler, _, tsr := tree.getValue(route, getParams, false) + if handler != nil { + t.Fatalf("non-nil handler for No-TSR route '%s", route) + } else if tsr { + t.Errorf("expected no TSR recommendation for route '%s'", route) + } + } +} + +func TestTreeTrailingSlashRedirect2(t *testing.T) { + tree := &node{} + + routes := [...]string{ + "/api/:version/seller/locales/get", + "/api/v:version/seller/permissions/get", + "/api/v:version/seller/university/entrance_knowledge_list/get", + } + for _, route := range routes { + recv := catchPanic(func() { + tree.addRoute(route, fakeHandler(route)) + }) + if recv != nil { + t.Fatalf("panic inserting route '%s': %v", route, recv) + } + } + + tsrRoutes := [...]string{ + "/api/v:version/seller/permissions/get/", + "/api/version/seller/permissions/get/", + } + + for _, route := range tsrRoutes { + handler, _, tsr := tree.getValue(route, getParams, false) + if handler != nil { + t.Fatalf("non-nil handler for TSR route '%s", route) + } else if !tsr { + t.Errorf("expected TSR recommendation for route '%s'", route) + } + } + + noTsrRoutes := [...]string{ + "/api/v:version/seller/permissions/get/a", + } + for _, route := range noTsrRoutes { + handler, _, tsr := tree.getValue(route, getParams, false) if handler != nil { t.Fatalf("non-nil handler for No-TSR route '%s", route) } else if tsr { @@ -458,71 +543,10 @@ func TestTreeRootTrailingSlashRedirect(t *testing.T) { t.Fatalf("panic inserting test route: %v", recv) } - handler, _, tsr := tree.getValue("/", nil) + handler, _, tsr := tree.getValue("/", nil, false) if handler != nil { t.Fatalf("non-nil handler") } else if tsr { t.Errorf("expected no TSR recommendation") } } - -func TestTreeInvalidNodeType(t *testing.T) { - const panicMsg = "invalid node type" - - tree := &node{} - tree.addRoute("/", fakeHandler("/")) - tree.addRoute("/:page", fakeHandler("/:page")) - - // set invalid node type - tree.children[0].nType = 42 - - // normal lookup - recv := catchPanic(func() { - tree.getValue("/test", nil) - }) - if rs, ok := recv.(string); !ok || rs != panicMsg { - t.Fatalf("Expected panic '"+panicMsg+"', got '%v'", recv) - } -} - -func TestTreeWildcardConflictEx(t *testing.T) { - conflicts := [...]struct { - route string - segPath string - existPath string - existSegPath string - }{ - {"/who/are/foo", "/foo", `/who/are/\*you`, `/\*you`}, - {"/who/are/foo/", "/foo/", `/who/are/\*you`, `/\*you`}, - {"/who/are/foo/bar", "/foo/bar", `/who/are/\*you`, `/\*you`}, - {"/conxxx", "xxx", `/con:tact`, `:tact`}, - {"/conooo/xxx", "ooo", `/con:tact`, `:tact`}, - } - - for i := range conflicts { - conflict := conflicts[i] - - // I have to re-create a 'tree', because the 'tree' will be - // in an inconsistent state when the loop recovers from the - // panic which threw by 'addRoute' function. - tree := &node{} - routes := [...]string{ - "/con:tact", - "/who/are/*you", - "/who/foo/hello", - } - - for i := range routes { - route := routes[i] - tree.addRoute(route, fakeHandler(route)) - } - - recv := catchPanic(func() { - tree.addRoute(conflict.route, fakeHandler(conflict.route)) - }) - - if !regexp.MustCompile(fmt.Sprintf("'%s' in new path .* conflicts with existing wildcard '%s' in existing prefix '%s'", conflict.segPath, conflict.existSegPath, conflict.existPath)).MatchString(fmt.Sprint(recv)) { - t.Fatalf("invalid wildcard conflict error (%v)", recv) - } - } -} diff --git a/pkg/generic/generic_service.go b/pkg/generic/generic_service.go index e39c6308d1..3b507665ce 100644 --- a/pkg/generic/generic_service.go +++ b/pkg/generic/generic_service.go @@ -124,6 +124,11 @@ func (g *Args) Read(ctx context.Context, method string, in thrift.TProtocol) err return fmt.Errorf("unexpected Args reader type: %T", g.inner) } +// GetFirstArgument implements util.KitexArgs. +func (g *Args) GetFirstArgument() interface{} { + return g.Request +} + // Result generic response type Result struct { Success interface{} @@ -159,7 +164,7 @@ func (r *Result) Read(ctx context.Context, method string, in thrift.TProtocol) e return fmt.Errorf("unexpected Result reader type: %T", r.inner) } -// GetSuccess ... +// GetSuccess implements util.KitexResult. func (r *Result) GetSuccess() interface{} { if !r.IsSetSuccess() { return nil @@ -167,7 +172,7 @@ func (r *Result) GetSuccess() interface{} { return r.Success } -// SetSuccess ... +// SetSuccess implements util.KitexResult. func (r *Result) SetSuccess(x interface{}) { r.Success = x } @@ -176,3 +181,8 @@ func (r *Result) SetSuccess(x interface{}) { func (r *Result) IsSetSuccess() bool { return r.Success != nil } + +// GetResult ... +func (r *Result) GetResult() interface{} { + return r.Success +} diff --git a/pkg/generic/generic_service_test.go b/pkg/generic/generic_service_test.go index 2e1dbe82a1..95eba59baf 100644 --- a/pkg/generic/generic_service_test.go +++ b/pkg/generic/generic_service_test.go @@ -30,6 +30,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote" codecThrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/utils" ) func TestGenericService(t *testing.T) { @@ -124,3 +125,13 @@ func TestServiceInfo(t *testing.T) { s := ServiceInfo(serviceinfo.Thrift) test.Assert(t, s.ServiceName == "$GenericService") } + +func TestArgsResult(t *testing.T) { + var args interface{} = &Args{} + _, kaOK := args.(utils.KitexArgs) + test.Assert(t, kaOK) + + var result interface{} = &Result{} + _, krOK := result.(utils.KitexResult) + test.Assert(t, krOK) +} diff --git a/pkg/generic/http_test/generic_test.go b/pkg/generic/http_test/generic_test.go index e66e2aebb7..418b144238 100644 --- a/pkg/generic/http_test/generic_test.go +++ b/pkg/generic/http_test/generic_test.go @@ -92,6 +92,7 @@ func testThriftNormalBinaryEcho(t *testing.T) { gr, ok := resp.(*generic.HTTPResponse) test.Assert(t, ok) test.Assert(t, gr.Body["msg"] == base64.StdEncoding.EncodeToString([]byte(mockMyMsg))) + test.Assert(t, gr.Body["num"] == "0") // string value for binary field which should fail body = map[string]interface{}{ @@ -117,6 +118,30 @@ func testThriftNormalBinaryEcho(t *testing.T) { test.Assert(t, ok) test.Assert(t, gr.Body["msg"] == mockMyMsg) + // []byte value for binary field + body = map[string]interface{}{ + "msg": []byte(mockMyMsg), + "got_base64": true, + } + data, err = json.Marshal(body) + if err != nil { + panic(err) + } + req, err = http.NewRequest(http.MethodGet, url, bytes.NewBuffer(data)) + if err != nil { + panic(err) + } + customReq, err = generic.FromHTTPRequest(req) + if err != nil { + t.Fatal(err) + } + resp, err = cli.GenericCall(context.Background(), "", customReq, callopt.WithRPCTimeout(100*time.Second)) + test.Assert(t, err == nil, err) + gr, ok = resp.(*generic.HTTPResponse) + test.Assert(t, ok) + test.Assert(t, gr.Body["msg"] == base64.StdEncoding.EncodeToString([]byte(mockMyMsg))) + test.Assert(t, gr.Body["num"] == "0") + body = map[string]interface{}{ "msg": []byte(mockMyMsg), "got_base64": true, diff --git a/pkg/generic/http_test/idl/binary_echo.thrift b/pkg/generic/http_test/idl/binary_echo.thrift index b6bb918920..8034438e6d 100644 --- a/pkg/generic/http_test/idl/binary_echo.thrift +++ b/pkg/generic/http_test/idl/binary_echo.thrift @@ -3,7 +3,8 @@ namespace go kitex.test.server struct BinaryWrapper { 1: binary msg (api.body = "msg") 2: bool got_base64 (api.body = "got_base64") - 3: i64 num (api.body = "num", api.js_conv="") + 3: required i64 num (api.body = "num", api.js_conv="") + 4: optional string str (api.body = "str") } service ExampleService { diff --git a/pkg/generic/httppbthrift_codec_test.go b/pkg/generic/httppbthrift_codec_test.go new file mode 100644 index 0000000000..4be52e9886 --- /dev/null +++ b/pkg/generic/httppbthrift_codec_test.go @@ -0,0 +1,42 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generic + +import ( + "bytes" + "io/ioutil" + "net/http" + "reflect" + "testing" + + "github.com/bytedance/mockey" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestFromHTTPPbRequest(t *testing.T) { + mockey.PatchConvey("TestFromHTTPPbRequest", t, func() { + req, err := http.NewRequest("POST", "/far/boo", bytes.NewBuffer([]byte("321"))) + test.Assert(t, err == nil) + mockey.Mock(ioutil.ReadAll).Return([]byte("123"), nil).Build() + hreq, err := FromHTTPPbRequest(req) + test.Assert(t, err == nil) + test.Assert(t, reflect.DeepEqual(hreq.RawBody, []byte("123")), string(hreq.RawBody)) + test.Assert(t, hreq.Method == "POST") + test.Assert(t, hreq.Path == "/far/boo") + }) +} diff --git a/pkg/generic/json_test/generic_init.go b/pkg/generic/json_test/generic_init.go index 0e4d1c39a2..fa0b289222 100644 --- a/pkg/generic/json_test/generic_init.go +++ b/pkg/generic/json_test/generic_init.go @@ -115,17 +115,7 @@ func (g *GenericServiceVoidImpl) GenericCall(ctx context.Context, method string, return descriptor.Void{}, nil } -// GenericServiceReadRequiredFiledImpl ... -type GenericServiceReadRequiredFiledImpl struct{} - -// GenericCall ... -func (g *GenericServiceReadRequiredFiledImpl) GenericCall(ctx context.Context, method string, request interface{}) (response interface{}, err error) { - msg := request.(string) - fmt.Printf("Recv: %s\n", msg) - return `{"Msg":"world"}`, nil -} - -// GenericServiceReadRequiredFiledImpl ... +// GenericServiceBinaryEchoImpl ... type GenericServiceBinaryEchoImpl struct{} const mockMyMsg = "my msg" diff --git a/pkg/generic/json_test/generic_test.go b/pkg/generic/json_test/generic_test.go index 9f3f375d57..100c7fd2e5 100644 --- a/pkg/generic/json_test/generic_test.go +++ b/pkg/generic/json_test/generic_test.go @@ -44,8 +44,6 @@ func TestRun(t *testing.T) { t.Run("TestThriftError", testThriftError) t.Run("TestThriftOnewayMethod", testThriftOnewayMethod) t.Run("TestThriftVoidMethod", testThriftVoidMethod) - t.Run("TestThriftReadRequiredField", testThriftReadRequiredField) - t.Run("TestThriftWriteRequiredField", testThriftWriteRequiredField) t.Run("TestThrift2NormalServer", testThrift2NormalServer) t.Run("TestJSONThriftGenericClientClose", TestJSONThriftGenericClientClose) t.Run("TestThriftRawBinaryEcho", testThriftRawBinaryEcho) @@ -122,32 +120,6 @@ func testThriftVoidMethod(t *testing.T) { svr.Stop() } -func testThriftReadRequiredField(t *testing.T) { - time.Sleep(1 * time.Second) - svr := initThriftServer(t, ":8126", new(GenericServiceReadRequiredFiledImpl)) - time.Sleep(500 * time.Millisecond) - - cli := initThriftClientByIDL(t, "127.0.0.1:8126", "./idl/example_check_read_required.thrift", nil) - - _, err := cli.GenericCall(context.Background(), "ExampleMethod", reqMsg, callopt.WithRPCTimeout(100*time.Second)) - test.Assert(t, err != nil, err) - test.Assert(t, strings.Contains(err.Error(), "required field (2/required_field) missing"), err.Error()) - svr.Stop() -} - -func testThriftWriteRequiredField(t *testing.T) { - time.Sleep(1 * time.Second) - svr := initThriftServer(t, ":8127", new(GenericServiceReadRequiredFiledImpl)) - time.Sleep(500 * time.Millisecond) - - cli := initThriftClientByIDL(t, "127.0.0.1:8127", "./idl/example_check_write_required.thrift", nil) - - _, err := cli.GenericCall(context.Background(), "ExampleMethod", reqMsg, callopt.WithRPCTimeout(100*time.Second)) - test.Assert(t, err != nil, err) - test.Assert(t, strings.Contains(err.Error(), "required field (2/Foo) missing"), err.Error()) - svr.Stop() -} - func testThrift2NormalServer(t *testing.T) { time.Sleep(4 * time.Second) svr := initMockServer(t, new(mockImpl)) diff --git a/pkg/generic/json_test/idl/binary_echo.thrift b/pkg/generic/json_test/idl/binary_echo.thrift index 8e27a013ed..ac7c937fdf 100644 --- a/pkg/generic/json_test/idl/binary_echo.thrift +++ b/pkg/generic/json_test/idl/binary_echo.thrift @@ -3,6 +3,7 @@ namespace go kitex.test.server struct BinaryWrapper { 1: binary msg 2: bool got_base64 + 3: optional string str } service ExampleService { diff --git a/pkg/generic/json_test/idl/example_check_read_required.thrift b/pkg/generic/json_test/idl/example_check_read_required.thrift deleted file mode 100644 index 6dfdb1ce89..0000000000 --- a/pkg/generic/json_test/idl/example_check_read_required.thrift +++ /dev/null @@ -1,35 +0,0 @@ -include "base.thrift" -include "self_ref.thrift" -namespace go kitex.test.server - -enum FOO { - A = 1; -} - -struct ExampleReq { - 1: required string Msg, - 2: FOO Foo, - 255: base.Base Base, -} -struct ExampleResp { - 1: required string Msg, - 2: required string required_field - 255: base.BaseResp BaseResp, -} -exception Exception { - 1: i32 code - 2: string msg -} - -struct A { - 1: A self - 2: self_ref.A a -} - -service ExampleService { - ExampleResp ExampleMethod(1: ExampleReq req)throws(1: Exception err), - A Foo(1: A req) - string Ping(1: string msg) - oneway void Oneway(1: string msg) - void Void(1: string msg) -} \ No newline at end of file diff --git a/pkg/generic/json_test/idl/example_check_write_required.thrift b/pkg/generic/json_test/idl/example_check_write_required.thrift deleted file mode 100644 index 5e341a99ac..0000000000 --- a/pkg/generic/json_test/idl/example_check_write_required.thrift +++ /dev/null @@ -1,36 +0,0 @@ -include "base.thrift" -include "self_ref.thrift" -namespace go kitex.test.server - -enum FOO { - A = 1; -} - -struct ExampleReq { - 1: required string Msg, - 2: required FOO Foo, - 255: base.Base Base, -} -struct ExampleResp { - 1: required string Msg, - 2: required string required_field - 3: string extra_field - 255: base.BaseResp BaseResp, -} -exception Exception { - 1: i32 code - 2: string msg -} - -struct A { - 1: A self - 2: self_ref.A a -} - -service ExampleService { - ExampleResp ExampleMethod(1: ExampleReq req)throws(1: Exception err), - A Foo(1: A req) - string Ping(1: string msg) - oneway void Oneway(1: string msg) - void Void(1: string msg) -} \ No newline at end of file diff --git a/pkg/generic/map_test/generic_init.go b/pkg/generic/map_test/generic_init.go index be3aab859b..4f4a56d417 100644 --- a/pkg/generic/map_test/generic_init.go +++ b/pkg/generic/map_test/generic_init.go @@ -119,18 +119,6 @@ func (g *GenericServiceVoidImpl) GenericCall(ctx context.Context, method string, return descriptor.Void{}, nil } -// GenericServiceReadRequiredFiledImpl ... -type GenericServiceReadRequiredFiledImpl struct{} - -// GenericCall ... -func (g *GenericServiceReadRequiredFiledImpl) GenericCall(ctx context.Context, method string, request interface{}) (response interface{}, err error) { - msg := request.(map[string]interface{}) - fmt.Printf("Recv: %v\n", msg) - return map[string]interface{}{ - "Msg": "world", - }, nil -} - var ( mockReq = map[string]interface{}{ "Msg": "hello", diff --git a/pkg/generic/map_test/generic_test.go b/pkg/generic/map_test/generic_test.go index 028dd44b72..96fbe4b028 100644 --- a/pkg/generic/map_test/generic_test.go +++ b/pkg/generic/map_test/generic_test.go @@ -69,6 +69,7 @@ func TestThrift(t *testing.T) { }, wantResp: map[string]interface{}{ "Msg": "hello", + "Foo": int32(0), "TestList": []interface{}{ map[string]interface{}{ "Bar": "foo", @@ -86,6 +87,12 @@ func TestThrift(t *testing.T) { int64(123), int64(456), }, "B": true, + "Base": map[string]interface{}{ + "LogID": "", + "Caller": "", + "Addr": "", + "Client": "", + }, }, }, { @@ -113,17 +120,22 @@ func TestThrift(t *testing.T) { }, wantResp: map[string]interface{}{ "Msg": "hello", + "Foo": int32(0), "TestList": []interface{}{ map[string]interface{}{ "Bar": "foo", }, - map[string]interface{}{}, + map[string]interface{}{ + "Bar": "", + }, }, "TestMap": map[interface{}]interface{}{ "l1": map[string]interface{}{ "Bar": "foo", }, - "l2": map[string]interface{}{}, + "l2": map[string]interface{}{ + "Bar": "", + }, }, "StrList": []interface{}{ "123", "", @@ -131,6 +143,13 @@ func TestThrift(t *testing.T) { "I64List": []interface{}{ int64(123), int64(0), }, + "B": false, + "Base": map[string]interface{}{ + "LogID": "", + "Caller": "", + "Addr": "", + "Client": "", + }, }, }, { @@ -152,11 +171,16 @@ func TestThrift(t *testing.T) { }, wantResp: map[string]interface{}{ "Msg": "hello", + "Foo": int32(0), "TestList": []interface{}{ - map[string]interface{}{}, + map[string]interface{}{ + "Bar": "", + }, }, "TestMap": map[interface{}]interface{}{ - "l2": map[string]interface{}{}, + "l2": map[string]interface{}{ + "Bar": "", + }, }, "StrList": []interface{}{ "", @@ -164,6 +188,13 @@ func TestThrift(t *testing.T) { "I64List": []interface{}{ int64(0), }, + "B": false, + "Base": map[string]interface{}{ + "LogID": "", + "Caller": "", + "Addr": "", + "Client": "", + }, }, }, { @@ -191,17 +222,22 @@ func TestThrift(t *testing.T) { }, wantResp: map[string]interface{}{ "Msg": "hello", + "Foo": int32(0), "TestList": []interface{}{ map[string]interface{}{ "Bar": "foo", }, - map[string]interface{}{}, + map[string]interface{}{ + "Bar": "", + }, }, "TestMap": map[interface{}]interface{}{ "l1": map[string]interface{}{ "Bar": "foo", }, - "l2": map[string]interface{}{}, + "l2": map[string]interface{}{ + "Bar": "", + }, }, "StrList": []interface{}{ "123", "", @@ -209,6 +245,13 @@ func TestThrift(t *testing.T) { "I64List": []interface{}{ int64(123), int64(0), int64(456), }, + "B": false, + "Base": map[string]interface{}{ + "LogID": "", + "Caller": "", + "Addr": "", + "Client": "", + }, }, }, } @@ -276,32 +319,6 @@ func TestThriftVoidMethod(t *testing.T) { svr.Stop() } -func TestThriftReadRequiredField(t *testing.T) { - time.Sleep(1 * time.Second) - svr := initThriftServer(t, ":9026", new(GenericServiceReadRequiredFiledImpl)) - time.Sleep(500 * time.Millisecond) - - cli := initThriftClientByIDL(t, "127.0.0.1:9026", "./idl/example_check_read_required.thrift") - - _, err := cli.GenericCall(context.Background(), "ExampleMethod", reqMsg, callopt.WithRPCTimeout(100*time.Second)) - test.Assert(t, err != nil, err) - test.Assert(t, strings.Contains(err.Error(), "required field (2/required_field) missing"), err.Error()) - svr.Stop() -} - -func TestThriftWriteRequiredField(t *testing.T) { - time.Sleep(1 * time.Second) - svr := initThriftServer(t, ":9027", new(GenericServiceReadRequiredFiledImpl)) - time.Sleep(500 * time.Millisecond) - - cli := initThriftClientByIDL(t, "127.0.0.1:9027", "./idl/example_check_write_required.thrift") - - _, err := cli.GenericCall(context.Background(), "ExampleMethod", reqMsg, callopt.WithRPCTimeout(100*time.Second)) - test.Assert(t, err != nil, err) - test.Assert(t, strings.Contains(err.Error(), "required field (2/Foo) missing"), err.Error()) - svr.Stop() -} - func TestThrift2NormalServer(t *testing.T) { time.Sleep(4 * time.Second) svr := initMockServer(t, new(mockImpl)) diff --git a/pkg/generic/map_test/idl/example.thrift b/pkg/generic/map_test/idl/example.thrift index 5bf8bb1e2c..ceb61dfe05 100644 --- a/pkg/generic/map_test/idl/example.thrift +++ b/pkg/generic/map_test/idl/example.thrift @@ -15,12 +15,12 @@ struct MockElem { } struct ExampleReq { - 1: required string Msg, + 1: required string Msg = "Hello", 2: FOO Foo, 3: list TestList, - 4: map TestMap, + 4: optional map TestMap, 5: list StrList, - 6: list I64List, + 6: list I64List = [1, 2, 3], 7: bool B, 255: base.Base Base, } diff --git a/pkg/generic/map_test/idl/example_check_read_required.thrift b/pkg/generic/map_test/idl/example_check_read_required.thrift deleted file mode 100644 index 6dfdb1ce89..0000000000 --- a/pkg/generic/map_test/idl/example_check_read_required.thrift +++ /dev/null @@ -1,35 +0,0 @@ -include "base.thrift" -include "self_ref.thrift" -namespace go kitex.test.server - -enum FOO { - A = 1; -} - -struct ExampleReq { - 1: required string Msg, - 2: FOO Foo, - 255: base.Base Base, -} -struct ExampleResp { - 1: required string Msg, - 2: required string required_field - 255: base.BaseResp BaseResp, -} -exception Exception { - 1: i32 code - 2: string msg -} - -struct A { - 1: A self - 2: self_ref.A a -} - -service ExampleService { - ExampleResp ExampleMethod(1: ExampleReq req)throws(1: Exception err), - A Foo(1: A req) - string Ping(1: string msg) - oneway void Oneway(1: string msg) - void Void(1: string msg) -} \ No newline at end of file diff --git a/pkg/generic/map_test/idl/example_check_write_required.thrift b/pkg/generic/map_test/idl/example_check_write_required.thrift deleted file mode 100644 index 5e341a99ac..0000000000 --- a/pkg/generic/map_test/idl/example_check_write_required.thrift +++ /dev/null @@ -1,36 +0,0 @@ -include "base.thrift" -include "self_ref.thrift" -namespace go kitex.test.server - -enum FOO { - A = 1; -} - -struct ExampleReq { - 1: required string Msg, - 2: required FOO Foo, - 255: base.Base Base, -} -struct ExampleResp { - 1: required string Msg, - 2: required string required_field - 3: string extra_field - 255: base.BaseResp BaseResp, -} -exception Exception { - 1: i32 code - 2: string msg -} - -struct A { - 1: A self - 2: self_ref.A a -} - -service ExampleService { - ExampleResp ExampleMethod(1: ExampleReq req)throws(1: Exception err), - A Foo(1: A req) - string Ping(1: string msg) - oneway void Oneway(1: string msg) - void Void(1: string msg) -} \ No newline at end of file diff --git a/pkg/generic/thrift/parse.go b/pkg/generic/thrift/parse.go index b4b487a465..2a438cb637 100644 --- a/pkg/generic/thrift/parse.go +++ b/pkg/generic/thrift/parse.go @@ -19,11 +19,13 @@ package thrift import ( "errors" "fmt" + "runtime/debug" "github.com/cloudwego/thriftgo/parser" "github.com/cloudwego/thriftgo/semantic" "github.com/cloudwego/kitex/pkg/generic/descriptor" + "github.com/cloudwego/kitex/pkg/klog" ) const ( @@ -141,7 +143,7 @@ func getAllFunctions(svc *parser.Service, tree *parser.Thrift, visitedSvcs map[* return ch } -func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.ServiceDescriptor, structsCache map[string]*descriptor.TypeDescriptor) error { +func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.ServiceDescriptor, structsCache map[string]*descriptor.TypeDescriptor) (err error) { if sDsc.Functions[fn.Name] != nil { return fmt.Errorf("duplicate method name: %s", fn.Name) } @@ -157,7 +159,8 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv FieldsByName: map[string]*descriptor.FieldDescriptor{}, }, } - reqType, err := parseType(field.Type, tree, structsCache, initRecursionDepth) + var reqType *descriptor.TypeDescriptor + reqType, err = parseType(field.Type, tree, structsCache, initRecursionDepth) if err != nil { return err } @@ -185,7 +188,8 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv FieldsByName: map[string]*descriptor.FieldDescriptor{}, }, } - respType, err := parseType(fn.FunctionType, tree, structsCache, initRecursionDepth) + var respType *descriptor.TypeDescriptor + respType, err = parseType(fn.FunctionType, tree, structsCache, initRecursionDepth) if err != nil { return err } @@ -199,7 +203,8 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv if len(fn.Throws) > 0 { // only support single exception field := fn.Throws[0] - exceptionType, err := parseType(field.Type, tree, structsCache, initRecursionDepth) + var exceptionType *descriptor.TypeDescriptor + exceptionType, err = parseType(field.Type, tree, structsCache, initRecursionDepth) if err != nil { return err } @@ -219,6 +224,12 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv Response: resp, HasRequestBase: hasRequestBase, } + defer func() { + if ret := recover(); ret != nil { + klog.Errorf("KITEX: router handle failed, err=%v\nstack=%s", ret, string(debug.Stack())) + err = fmt.Errorf("router handle failed, err=%v", ret) + } + }() for _, ann := range fn.Annotations { for _, v := range ann.GetValues() { if handle, ok := descriptor.FindAnnotation(ann.GetKey(), v); ok { diff --git a/pkg/generic/thrift/parse_test.go b/pkg/generic/thrift/parse_test.go index 4bf3939534..418be486bf 100644 --- a/pkg/generic/thrift/parse_test.go +++ b/pkg/generic/thrift/parse_test.go @@ -98,6 +98,32 @@ func TestParseHttpIDL(t *testing.T) { test.Assert(t, len(bizMethod1.Response.Struct.FieldsByID) == 2) } +var httpConflictPathIDL = ` +namespace go http + +struct BizRequest { + 1: optional i32 api_version(api.path = 'action') + 2: optional i64 uid(api.path = 'biz') +} + +struct BizResponse { + 1: optional string T(api.header= 'T') +} + +service BizService{ +BizResponse BizMethod1(1: BizRequest req)(api.post = '/life/client/:action/:biz/*one', api.baseurl = 'ib.snssdk.com', api.param = 'true', api.serializer = 'form') +BizResponse BizMethod2(1: BizRequest req)(api.post = '/life/client/:action/:biz/*two', api.baseurl = 'ib.snssdk.com', api.param = 'true', api.serializer = 'json') +BizResponse BizMethod3(1: BizRequest req)(api.post = '/life/client/:action/:biz/*three', api.baseurl = 'ib.snssdk.com', api.param = 'true', api.serializer = 'json') +}` + +func TestPanicRecover(t *testing.T) { + re, err := parser.ParseString("http.thrift", httpConflictPathIDL) + test.Assert(t, err == nil, err) + _, err = Parse(re, LastServiceOnly) + test.Assert(t, err != nil) + test.DeepEqual(t, err.Error(), "router handle failed, err=handlers are already registered for path '/life/client/:action/:biz/*two'") +} + var selfReferenceIDL = ` namespace go http diff --git a/pkg/generic/thrift/write.go b/pkg/generic/thrift/write.go index 2fccf5f6da..cddfcc155b 100644 --- a/pkg/generic/thrift/write.go +++ b/pkg/generic/thrift/write.go @@ -684,10 +684,14 @@ func writeStruct(ctx context.Context, val interface{}, out thrift.TProtocol, t * continue } if !ok || elem == nil { - if field.Required { - return fmt.Errorf("required field (%d/%s) missing", field.ID, name) + if !field.Optional { + elem, _, err = getDefaultValueAndWriter(field.Type, opt) + if err != nil { + return fmt.Errorf("field (%d/%s) error: %w", field.ID, name, err) + } + } else { + continue } - continue } if field.ValueMapping != nil { elem, err = field.ValueMapping.Request(ctx, elem, field) @@ -737,10 +741,14 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out thrift.TProtocol continue } if v == nil { - if field.Required { - return fmt.Errorf("required field (%d/%s) missing", field.ID, name) + if !field.Optional { + v, _, err = getDefaultValueAndWriter(field.Type, opt) + if err != nil { + return fmt.Errorf("field (%d/%s) error: %w", field.ID, name, err) + } + } else { + continue } - continue } if field.ValueMapping != nil { if v, err = field.ValueMapping.Request(ctx, v, field); err != nil { @@ -784,10 +792,9 @@ func writeJSON(ctx context.Context, val interface{}, out thrift.TProtocol, t *de } if elem.Type == gjson.Null { - if field.Required { - return perrors.NewProtocolErrorWithType(perrors.InvalidData, fmt.Sprintf("required field (%d/%s) missing", field.ID, name)) + if field.Optional { + continue } - continue } v, writer, err := nextJSONWriter(&elem, field.Type, opt) diff --git a/pkg/generic/thrift/write_test.go b/pkg/generic/thrift/write_test.go index 90dfda3bfb..0b1317ce88 100644 --- a/pkg/generic/thrift/write_test.go +++ b/pkg/generic/thrift/write_test.go @@ -1039,6 +1039,47 @@ func Test_writeStruct(t *testing.T) { }, false, }, + { + "writeStructRequired", + args{ + val: map[string]interface{}{"hello": nil}, + out: mockTTransport, + t: &descriptor.TypeDescriptor{ + Type: descriptor.STRUCT, + Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Struct: &descriptor.StructDescriptor{ + Name: "Demo", + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "hello": {Name: "hello", ID: 1, Required: true, Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, + }, + RequiredFields: map[int32]*descriptor.FieldDescriptor{ + 1: {Name: "hello", ID: 1, Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, + }, + }, + }, + }, + false, + }, + { + "writeStructOptional", + args{ + val: map[string]interface{}{}, + out: mockTTransport, + t: &descriptor.TypeDescriptor{ + Type: descriptor.STRUCT, + Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Struct: &descriptor.StructDescriptor{ + Name: "Demo", + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "hello": {Name: "hello", ID: 1, Optional: true, DefaultValue: "Hello", Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, + }, + }, + }, + }, + false, + }, { "writeStructError", args{ @@ -1087,9 +1128,6 @@ func Test_writeHTTPRequest(t *testing.T) { return nil }, } - req := &descriptor.HTTPRequest{ - Body: map[string]interface{}{"hello": "world"}, - } tests := []struct { name string args args @@ -1099,7 +1137,89 @@ func Test_writeHTTPRequest(t *testing.T) { { "writeStruct", args{ - val: req, + val: &descriptor.HTTPRequest{ + Body: map[string]interface{}{"hello": "world"}, + }, + out: mockTTransport, + t: &descriptor.TypeDescriptor{ + Type: descriptor.STRUCT, + Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Struct: &descriptor.StructDescriptor{ + Name: "Demo", + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "hello": { + Name: "hello", + ID: 1, + Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + HTTPMapping: descriptor.DefaultNewMapping("hello"), + }, + }, + }, + }, + }, + false, + }, + { + "writeStructRequired", + args{ + val: &descriptor.HTTPRequest{ + Body: map[string]interface{}{"hello": nil}, + }, + out: mockTTransport, + t: &descriptor.TypeDescriptor{ + Type: descriptor.STRUCT, + Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Struct: &descriptor.StructDescriptor{ + Name: "Demo", + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "hello": { + Name: "hello", + ID: 1, + Required: true, + Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + HTTPMapping: descriptor.DefaultNewMapping("hello"), + }, + }, + }, + }, + }, + false, + }, + { + "writeStructDefault", + args{ + val: &descriptor.HTTPRequest{ + Body: map[string]interface{}{"hello": nil}, + }, + out: mockTTransport, + t: &descriptor.TypeDescriptor{ + Type: descriptor.STRUCT, + Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Struct: &descriptor.StructDescriptor{ + Name: "Demo", + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "hello": { + Name: "hello", + ID: 1, + DefaultValue: "world", + Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + HTTPMapping: descriptor.DefaultNewMapping("hello"), + }, + }, + }, + }, + }, + false, + }, + { + "writeStructOptional", + args{ + val: &descriptor.HTTPRequest{ + Body: map[string]interface{}{}, + }, out: mockTTransport, t: &descriptor.TypeDescriptor{ Type: descriptor.STRUCT, @@ -1111,6 +1231,7 @@ func Test_writeHTTPRequest(t *testing.T) { "hello": { Name: "hello", ID: 1, + Optional: true, Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}, HTTPMapping: descriptor.DefaultNewMapping("hello"), }, @@ -1327,6 +1448,7 @@ func Test_writeJSON(t *testing.T) { }, } data := gjson.Parse(`{"hello": "world"}`) + dataEmpty := gjson.Parse(`{"hello": nil}`) tests := []struct { name string args args @@ -1355,6 +1477,47 @@ func Test_writeJSON(t *testing.T) { }, false, }, + { + "writeJSONRequired", + args{ + val: &dataEmpty, + out: mockTTransport, + t: &descriptor.TypeDescriptor{ + Type: descriptor.STRUCT, + Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Struct: &descriptor.StructDescriptor{ + Name: "Demo", + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "hello": {Name: "hello", ID: 1, Required: true, Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, + }, + RequiredFields: map[int32]*descriptor.FieldDescriptor{ + 1: {Name: "hello", ID: 1, Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, + }, + }, + }, + }, + false, + }, + { + "writeJSONOptional", + args{ + val: &dataEmpty, + out: mockTTransport, + t: &descriptor.TypeDescriptor{ + Type: descriptor.STRUCT, + Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Struct: &descriptor.StructDescriptor{ + Name: "Demo", + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "hello": {Name: "hello", ID: 1, Optional: true, Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}}, + }, + }, + }, + }, + false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/loadbalance/iterator.go b/pkg/loadbalance/iterator.go new file mode 100644 index 0000000000..f07d1d655c --- /dev/null +++ b/pkg/loadbalance/iterator.go @@ -0,0 +1,41 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "sync/atomic" + + "github.com/bytedance/gopkg/lang/fastrand" +) + +// round implement a strict Round Robin algorithm +type round struct { + state uint64 // 8 bytes + _ [7]uint64 // + 7 * 8 bytes + // = 64 bytes +} + +func (r *round) Next() uint64 { + return atomic.AddUint64(&r.state, 1) +} + +func newRound() *round { + r := &round{ + state: fastrand.Uint64(), // every thread have a rand start order + } + return r +} diff --git a/pkg/loadbalance/lbcache/cache.go b/pkg/loadbalance/lbcache/cache.go index 8fb2ad75b6..f4d60bc6b6 100644 --- a/pkg/loadbalance/lbcache/cache.go +++ b/pkg/loadbalance/lbcache/cache.go @@ -31,6 +31,7 @@ import ( "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/loadbalance" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/utils" ) const ( @@ -54,6 +55,9 @@ type Options struct { // DiagnosisService is used register info for diagnosis DiagnosisService diagnosis.Service + + // Cacheable is used to indicate that if the factory could be shared between multi clients + Cacheable bool } func (v *Options) check() { @@ -91,29 +95,37 @@ func cacheKey(resolver, balancer string, opts Options) string { return fmt.Sprintf("%s|%s|{%s %s}", resolver, balancer, opts.RefreshInterval, opts.ExpireInterval) } +func newBalancerFactory(resolver discovery.Resolver, balancer loadbalance.Loadbalancer, opts Options) *BalancerFactory { + b := &BalancerFactory{ + opts: opts, + resolver: resolver, + balancer: balancer, + } + if rb, ok := balancer.(loadbalance.Rebalancer); ok { + hrb := newHookRebalancer(rb) + b.rebalancer = hrb + b.Hookable = hrb + } else { + b.Hookable = noopHookRebalancer{} + } + go b.watcher() + return b +} + // NewBalancerFactory get or create a balancer factory for balancer instance // cache key with resolver name, balancer name and options func NewBalancerFactory(resolver discovery.Resolver, balancer loadbalance.Loadbalancer, opts Options) *BalancerFactory { opts.check() + if !opts.Cacheable { + return newBalancerFactory(resolver, balancer, opts) + } uniqueKey := cacheKey(resolver.Name(), balancer.Name(), opts) val, ok := balancerFactories.Load(uniqueKey) if ok { return val.(*BalancerFactory) } val, _, _ = balancerFactoriesSfg.Do(uniqueKey, func() (interface{}, error) { - b := &BalancerFactory{ - opts: opts, - resolver: resolver, - balancer: balancer, - } - if rb, ok := balancer.(loadbalance.Rebalancer); ok { - hrb := newHookRebalancer(rb) - b.rebalancer = hrb - b.Hookable = hrb - } else { - b.Hookable = noopHookRebalancer{} - } - go b.watcher() + b := newBalancerFactory(resolver, balancer, opts) balancerFactories.Store(uniqueKey, b) return b, nil }) @@ -178,10 +190,10 @@ type Balancer struct { target string // a description returned from the resolver's Target method res atomic.Value // newest and previous discovery result expire int32 // 0 = normal, 1 = expire and collect next ticker - sharedTicker *sharedTicker + sharedTicker *utils.SharedTicker } -func (bl *Balancer) refresh() { +func (bl *Balancer) Refresh() { res, err := bl.b.resolver.Resolve(context.Background(), bl.target) if err != nil { klog.Warnf("KITEX: resolver refresh failed, key=%s error=%s", bl.target, err.Error()) @@ -198,6 +210,11 @@ func (bl *Balancer) refresh() { bl.res.Store(res) } +// Tick implements the interface utils.TickerTask. +func (bl *Balancer) Tick() { + bl.Refresh() +} + // GetResult returns the discovery result that the Balancer holds. func (bl *Balancer) GetResult() (res discovery.Result, ok bool) { if v := bl.res.Load(); v != nil { @@ -225,7 +242,7 @@ func (bl *Balancer) close() { }) } // delete from sharedTicker - bl.sharedTicker.delete(bl) + bl.sharedTicker.Delete(bl) } const unknown = "unknown" diff --git a/pkg/loadbalance/lbcache/cache_test.go b/pkg/loadbalance/lbcache/cache_test.go index 0c7eb79629..bee368b5cf 100644 --- a/pkg/loadbalance/lbcache/cache_test.go +++ b/pkg/loadbalance/lbcache/cache_test.go @@ -61,7 +61,7 @@ func TestBuilder(t *testing.T) { return picker }).AnyTimes() lb.EXPECT().Name().Return("Synthesized").AnyTimes() - NewBalancerFactory(r, lb, Options{}) + NewBalancerFactory(r, lb, Options{Cacheable: true}) b, ok := balancerFactories.Load(cacheKey(t.Name(), "Synthesized", defaultOptions)) test.Assert(t, ok) test.Assert(t, b != nil) diff --git a/pkg/loadbalance/lbcache/shared_ticker.go b/pkg/loadbalance/lbcache/shared_ticker.go index f648fe3348..2b4bffd63e 100644 --- a/pkg/loadbalance/lbcache/shared_ticker.go +++ b/pkg/loadbalance/lbcache/shared_ticker.go @@ -21,6 +21,8 @@ import ( "time" "golang.org/x/sync/singleflight" + + "github.com/cloudwego/kitex/pkg/utils" ) var ( @@ -29,84 +31,21 @@ var ( sharedTickersSfg singleflight.Group ) -// shared ticker -type sharedTicker struct { - sync.Mutex - started bool - interval time.Duration - tasks map[*Balancer]struct{} - stopChan chan struct{} -} - -func getSharedTicker(b *Balancer, refreshInterval time.Duration) *sharedTicker { +func getSharedTicker(b *Balancer, refreshInterval time.Duration) *utils.SharedTicker { sti, ok := sharedTickers.Load(refreshInterval) if ok { - st := sti.(*sharedTicker) - st.add(b) + st := sti.(*utils.SharedTicker) + st.Add(b) return st } v, _, _ := sharedTickersSfg.Do(refreshInterval.String(), func() (interface{}, error) { - st := &sharedTicker{ - interval: refreshInterval, - tasks: map[*Balancer]struct{}{}, - stopChan: make(chan struct{}, 1), - } + st := utils.NewSharedTicker(refreshInterval) sharedTickers.Store(refreshInterval, st) return st, nil }) - st := v.(*sharedTicker) - // add without singleflight, - // because we need all balancers those call this function to add themself to sharedTicker - st.add(b) + st := v.(*utils.SharedTicker) + // Add without singleflight, + // because we need all refreshers those call this function to add themselves to SharedTicker + st.Add(b) return st } - -func (t *sharedTicker) add(b *Balancer) { - t.Lock() - defer t.Unlock() - // add task - t.tasks[b] = struct{}{} - if !t.started { - t.started = true - go t.tick(t.interval) - } -} - -func (t *sharedTicker) delete(b *Balancer) { - t.Lock() - defer t.Unlock() - // delete from tasks - delete(t.tasks, b) - // no tasks remaining then stop the tick - if len(t.tasks) == 0 { - // unblocked when multi delete call - select { - case t.stopChan <- struct{}{}: - t.started = false - default: - } - } -} - -func (t *sharedTicker) tick(interval time.Duration) { - var wg sync.WaitGroup - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - t.Lock() - for b := range t.tasks { - wg.Add(1) - go func(b *Balancer) { - defer wg.Done() - b.refresh() - }(b) - } - wg.Wait() - t.Unlock() - case <-t.stopChan: - return - } - } -} diff --git a/pkg/loadbalance/weighted_balancer.go b/pkg/loadbalance/weighted_balancer.go new file mode 100644 index 0000000000..bcf5acc45d --- /dev/null +++ b/pkg/loadbalance/weighted_balancer.go @@ -0,0 +1,137 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "sync" + + "golang.org/x/sync/singleflight" + + "github.com/cloudwego/kitex/pkg/discovery" + "github.com/cloudwego/kitex/pkg/klog" +) + +const ( + lbKindRoundRobin = 1 + lbKindRandom = 2 +) + +type weightedBalancer struct { + kind int + pickerCache sync.Map + sfg singleflight.Group +} + +// NewWeightedBalancer creates a loadbalancer using weighted-round-robin algorithm. +func NewWeightedBalancer() Loadbalancer { + return NewWeightedRoundRobinBalancer() +} + +// NewWeightedRoundRobinBalancer creates a loadbalancer using weighted-round-robin algorithm. +func NewWeightedRoundRobinBalancer() Loadbalancer { + lb := &weightedBalancer{kind: lbKindRoundRobin} + return lb +} + +// NewWeightedRandomBalancer creates a loadbalancer using weighted-random algorithm. +func NewWeightedRandomBalancer() Loadbalancer { + lb := &weightedBalancer{kind: lbKindRandom} + return lb +} + +// GetPicker implements the Loadbalancer interface. +func (wb *weightedBalancer) GetPicker(e discovery.Result) Picker { + if !e.Cacheable { + picker := wb.createPicker(e) + return picker + } + + picker, ok := wb.pickerCache.Load(e.CacheKey) + if !ok { + picker, _, _ = wb.sfg.Do(e.CacheKey, func() (interface{}, error) { + p := wb.createPicker(e) + wb.pickerCache.Store(e.CacheKey, p) + return p, nil + }) + } + return picker.(Picker) +} + +func (wb *weightedBalancer) createPicker(e discovery.Result) (picker Picker) { + instances := make([]discovery.Instance, len(e.Instances)) // removed zero weight instances + weightSum := 0 + balance := true + cnt := 0 + for idx, instance := range e.Instances { + weight := instance.Weight() + if weight <= 0 { + klog.Warnf("KITEX: invalid weight, weight=%d instance=%s", weight, e.Instances[idx].Address()) + continue + } + weightSum += weight + instances[cnt] = instance + if cnt > 0 && instances[cnt-1].Weight() != weight { + balance = false + } + cnt++ + } + instances = instances[:cnt] + if len(instances) == 0 { + return new(DummyPicker) + } + + switch wb.kind { + case lbKindRoundRobin: + if balance { + picker = newRoundRobinPicker(instances) + } else { + picker = newWeightedRoundRobinPicker(instances) + } + default: // random + if balance { + picker = newRandomPicker(instances) + } else { + picker = newWeightedRandomPickerWithSum(instances, weightSum) + } + } + return picker +} + +// Rebalance implements the Rebalancer interface. +func (wb *weightedBalancer) Rebalance(change discovery.Change) { + if !change.Result.Cacheable { + return + } + wb.pickerCache.Store(change.Result.CacheKey, wb.createPicker(change.Result)) +} + +// Delete implements the Rebalancer interface. +func (wb *weightedBalancer) Delete(change discovery.Change) { + if !change.Result.Cacheable { + return + } + wb.pickerCache.Delete(change.Result.CacheKey) +} + +func (wb *weightedBalancer) Name() string { + switch wb.kind { + case lbKindRoundRobin: + return "weight_round_robin" + default: + return "weight_random" + } +} diff --git a/pkg/loadbalance/weighted_balancer_test.go b/pkg/loadbalance/weighted_balancer_test.go new file mode 100644 index 0000000000..f9882c4126 --- /dev/null +++ b/pkg/loadbalance/weighted_balancer_test.go @@ -0,0 +1,292 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "context" + "fmt" + "math" + "strconv" + "testing" + + "github.com/cloudwego/kitex/internal" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/discovery" +) + +type balancerTestcase struct { + Name string + factory func() Loadbalancer +} + +var balancerTestcases = []*balancerTestcase{ + {Name: "weight_round_robin", factory: NewWeightedRoundRobinBalancer}, + {Name: "weight_random", factory: NewWeightedRandomBalancer}, +} + +func TestWeightedBalancer_GetPicker(t *testing.T) { + for _, tc := range balancerTestcases { + t.Run(tc.Name, func(t *testing.T) { + balancer := tc.factory() + // nil + picker := balancer.GetPicker(discovery.Result{}) + test.Assert(t, picker != nil) + dp, ok := picker.(*DummyPicker) + test.Assert(t, ok && dp != nil) + + // invalid + picker = balancer.GetPicker(discovery.Result{ + Instances: []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", -10, nil), + discovery.NewInstance("tcp", "addr2", -20, nil), + }, + }) + test.Assert(t, picker != nil) + dp, ok = picker.(*DummyPicker) + test.Assert(t, ok && dp != nil) + + // one instance + insList := []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", 10, nil), + } + picker = balancer.GetPicker(discovery.Result{ + Instances: insList, + }) + test.Assert(t, picker != nil) + + // multi instances + insList = []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", 10, nil), + discovery.NewInstance("tcp", "addr2", 20, nil), + discovery.NewInstance("tcp", "addr3", 30, nil), + } + picker = balancer.GetPicker(discovery.Result{ + Instances: insList, + }) + test.Assert(t, picker != nil) + + // balanced instances + insList = []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", 10, nil), + discovery.NewInstance("tcp", "addr2", 10, nil), + discovery.NewInstance("tcp", "addr3", 10, nil), + } + picker = balancer.GetPicker(discovery.Result{ + Instances: insList, + }) + test.Assert(t, picker != nil) + test.Assert(t, balancer.Name() == tc.Name) + }) + } +} + +func TestWeightedPicker_Next(t *testing.T) { + for _, tc := range balancerTestcases { + t.Run(tc.Name, func(t *testing.T) { + balancer := tc.factory() + ctx := context.Background() + // nil + picker := balancer.GetPicker(discovery.Result{}) + ins := picker.Next(ctx, nil) + test.Assert(t, ins == nil) + + // empty instance + picker = balancer.GetPicker(discovery.Result{ + Instances: make([]discovery.Instance, 0), + }) + ins = picker.Next(ctx, nil) + test.Assert(t, ins == nil) + + // one instance + insList := []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", 10, nil), + } + for i := 0; i < 100; i++ { + picker := balancer.GetPicker(discovery.Result{ + Instances: insList, + }) + ins := picker.Next(ctx, nil) + test.Assert(t, ins.Weight() == 10) + } + + // multi instances, weightSum > 0 + insList = []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", 10, nil), + discovery.NewInstance("tcp", "addr2", 20, nil), + discovery.NewInstance("tcp", "addr3", 50, nil), + discovery.NewInstance("tcp", "addr4", 100, nil), + discovery.NewInstance("tcp", "addr5", 200, nil), + discovery.NewInstance("tcp", "addr6", 500, nil), + } + var weightSum int + for _, ins := range insList { + weight := ins.Weight() + weightSum += weight + } + n := 10000000 + pickedStat := map[int]int{} + for i := 0; i < n; i++ { + picker := balancer.GetPicker(discovery.Result{ + Instances: insList, + Cacheable: true, + }) + ins := picker.Next(ctx, nil) + weight := ins.Weight() + if pickedCnt, ok := pickedStat[weight]; ok { + pickedStat[weight] = pickedCnt + 1 + } else { + pickedStat[weight] = 1 + } + } + + for _, ins := range insList { + weight := ins.Weight() + expect := float64(weight) / float64(weightSum) * float64(n) + actual := float64(pickedStat[weight]) + delta := math.Abs(expect - actual) + test.Assertf(t, delta/expect < 0.01, "delta(%f)/expect(%f) = %f", delta, expect, delta/expect) + } + + // weightSum = 0 + insList = []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", 10, nil), + discovery.NewInstance("tcp", "addr2", -10, nil), + } + picker = balancer.GetPicker(discovery.Result{ + Instances: insList, + }) + test.Assert(t, picker.Next(ctx, nil) != nil) + + // weightSum < 0 + insList = []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", 10, nil), + discovery.NewInstance("tcp", "addr2", -20, nil), + } + picker = balancer.GetPicker(discovery.Result{ + Instances: insList, + }) + test.Assert(t, picker.Next(ctx, nil) != nil) + }) + } +} + +func TestWeightedPicker_NoMoreInstance(t *testing.T) { + for _, tc := range balancerTestcases { + t.Run(tc.Name, func(t *testing.T) { + balancer := tc.factory() + ctx := context.Background() + + // multi instances, weightSum > 0 + insList := []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", 10, nil), + discovery.NewInstance("tcp", "addr2", 20, nil), + discovery.NewInstance("tcp", "addr3", 50, nil), + discovery.NewInstance("tcp", "addr4", 100, nil), + discovery.NewInstance("tcp", "addr5", 200, nil), + discovery.NewInstance("tcp", "addr6", 500, nil), + } + + picker := balancer.GetPicker(discovery.Result{ + Instances: insList, + }) + + for i := 0; i < len(insList); i++ { + ins := picker.Next(ctx, nil) + test.Assert(t, ins != nil) + } + + ins := picker.Next(ctx, nil) + test.Assert(t, ins != nil) + }) + } +} + +func makeNinstances(n int) (res []discovery.Instance) { + for i := 0; i < n; i++ { + res = append(res, discovery.NewInstance("tcp", "addr"+strconv.Itoa(i), 10, nil)) + } + return +} + +func makeNWeightedInstances(n int) (res []discovery.Instance) { + for i := 0; i < n; i++ { + res = append(res, discovery.NewInstance("tcp", "addr"+strconv.Itoa(i), i, nil)) + } + return +} + +func BenchmarkWeightedPicker(b *testing.B) { + ctx := context.Background() + + for _, tc := range balancerTestcases { + b.Run(tc.Name, func(b *testing.B) { + balancer := tc.factory() + + n := 10 + for i := 0; i < 4; i++ { + b.Run(fmt.Sprintf("%dins", n), func(b *testing.B) { + inss := makeNinstances(n) + e := discovery.Result{ + Cacheable: true, + CacheKey: "test", + Instances: inss, + } + picker := balancer.GetPicker(e) + picker.Next(ctx, nil) + if r, ok := picker.(internal.Reusable); ok { + r.Recycle() + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + picker := balancer.GetPicker(e) + picker.Next(ctx, nil) + if r, ok := picker.(internal.Reusable); ok { + r.Recycle() + } + } + }) + n *= 10 + } + }) + } +} + +func BenchmarkGetPicker(b *testing.B) { + insList := genInstList() + for _, tc := range balancerTestcases { + b.Run(tc.Name, func(b *testing.B) { + balancer := tc.factory() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + balancer.GetPicker(discovery.Result{ + Instances: insList, + }) + } + }) + } +} + +func genInstList() []discovery.Instance { + n := 1000 + insList := make([]discovery.Instance, 0, n) + for i := 0; i < n; i++ { + insList = append(insList, discovery.NewInstance("tcp", "addr"+strconv.Itoa(i), 10, nil)) + } + return insList +} diff --git a/pkg/loadbalance/weighted_random.go b/pkg/loadbalance/weighted_random.go index 219bd6c014..5272555038 100644 --- a/pkg/loadbalance/weighted_random.go +++ b/pkg/loadbalance/weighted_random.go @@ -1,5 +1,5 @@ /* - * Copyright 2021 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,276 +18,48 @@ package loadbalance import ( "context" - "sync" "github.com/bytedance/gopkg/lang/fastrand" - "golang.org/x/sync/singleflight" "github.com/cloudwego/kitex/pkg/discovery" - "github.com/cloudwego/kitex/pkg/klog" ) -var weightedPickerPool, randomPickerPool sync.Pool - -func init() { - weightedPickerPool.New = newWeightedPicker - randomPickerPool.New = newRandomPicker -} - -type entry = int - -type weightedPicker struct { - immutableInstances []discovery.Instance - immutableEntries []entry - weightSum int - - copiedInstances []discovery.Instance - copiedEntries []entry - firstIndex int +type weightedRandomPicker struct { + instances []discovery.Instance + weightSum int } -func newWeightedPicker() interface{} { - return &weightedPicker{ - firstIndex: -1, +func newWeightedRandomPickerWithSum(instances []discovery.Instance, weightSum int) Picker { + return &weightedRandomPicker{ + instances: instances, + weightSum: weightSum, } } // Next implements the Picker interface. -func (wp *weightedPicker) Next(ctx context.Context, request interface{}) (ins discovery.Instance) { - if wp.firstIndex < 0 { - weight := fastrand.Intn(wp.weightSum) - for i := 0; i < len(wp.immutableEntries); i++ { - weight -= wp.immutableEntries[i] - if weight < 0 { - wp.firstIndex = i - break - } - } - return wp.immutableInstances[wp.firstIndex] - } - - if wp.copiedInstances == nil { - wp.copiedInstances = make([]discovery.Instance, len(wp.immutableInstances)-1) - copy(wp.copiedInstances, wp.immutableInstances[:wp.firstIndex]) - copy(wp.copiedInstances[wp.firstIndex:], wp.immutableInstances[wp.firstIndex+1:]) - - wp.copiedEntries = make([]entry, len(wp.immutableEntries)-1) - copy(wp.copiedEntries, wp.immutableEntries[:wp.firstIndex]) - copy(wp.copiedEntries[wp.firstIndex:], wp.immutableEntries[wp.firstIndex+1:]) - - wp.weightSum -= wp.immutableEntries[wp.firstIndex] - } - - n := len(wp.copiedInstances) - if n > 0 { - weight := fastrand.Intn(wp.weightSum) - for i := 0; i < len(wp.copiedEntries); i++ { - weight -= wp.copiedEntries[i] - if weight < 0 { - wp.weightSum -= wp.copiedEntries[i] - ins := wp.copiedInstances[i] - wp.copiedInstances[i] = wp.copiedInstances[n-1] - wp.copiedInstances = wp.copiedInstances[:n-1] - wp.copiedEntries[i] = wp.copiedEntries[n-1] - wp.copiedEntries = wp.copiedEntries[:n-1] - return ins - } +func (wp *weightedRandomPicker) Next(ctx context.Context, request interface{}) (ins discovery.Instance) { + weight := fastrand.Intn(wp.weightSum) + for i := 0; i < len(wp.instances); i++ { + weight -= wp.instances[i].Weight() + if weight < 0 { + return wp.instances[i] } } return nil } -func (wp *weightedPicker) zero() { - wp.immutableInstances = nil - wp.immutableEntries = nil - wp.weightSum = 0 - wp.copiedInstances = nil - wp.copiedEntries = nil - wp.firstIndex = -1 -} - -func (wp *weightedPicker) Recycle() { - wp.zero() - weightedPickerPool.Put(wp) -} - type randomPicker struct { - immutableInstances []discovery.Instance - firstIndex int - copiedInstances []discovery.Instance + instances []discovery.Instance } -func newRandomPicker() interface{} { +func newRandomPicker(instances []discovery.Instance) Picker { return &randomPicker{ - firstIndex: -1, + instances: instances, } } // Next implements the Picker interface. func (rp *randomPicker) Next(ctx context.Context, request interface{}) (ins discovery.Instance) { - if rp.firstIndex < 0 { - rp.firstIndex = fastrand.Intn(len(rp.immutableInstances)) - return rp.immutableInstances[rp.firstIndex] - } - - if rp.copiedInstances == nil { - rp.copiedInstances = make([]discovery.Instance, len(rp.immutableInstances)-1) - copy(rp.copiedInstances, rp.immutableInstances[:rp.firstIndex]) - copy(rp.copiedInstances[rp.firstIndex:], rp.immutableInstances[rp.firstIndex+1:]) - } - - n := len(rp.copiedInstances) - if n > 0 { - index := fastrand.Intn(n) - ins := rp.copiedInstances[index] - rp.copiedInstances[index] = rp.copiedInstances[n-1] - rp.copiedInstances = rp.copiedInstances[:n-1] - return ins - } - - return nil -} - -func (rp *randomPicker) zero() { - rp.copiedInstances = nil - rp.immutableInstances = nil - rp.firstIndex = -1 -} - -func (rp *randomPicker) Recycle() { - rp.zero() - randomPickerPool.Put(rp) -} - -type weightInfo struct { - balance bool - instances []discovery.Instance - entries []entry - weightSum int -} - -type weightedBalancer struct { - cachedWeightInfo sync.Map - sfg singleflight.Group -} - -// NewWeightedBalancer creates a loadbalancer using weighted-round-robin algorithm. -func NewWeightedBalancer() Loadbalancer { - lb := &weightedBalancer{} - return lb -} - -// GetPicker implements the Loadbalancer interface. -func (wb *weightedBalancer) GetPicker(e discovery.Result) Picker { - var w *weightInfo - if e.Cacheable { - wi, ok := wb.cachedWeightInfo.Load(e.CacheKey) - if !ok { - wi, _, _ = wb.sfg.Do(e.CacheKey, func() (interface{}, error) { - return wb.calcWeightInfo(e), nil - }) - wb.cachedWeightInfo.Store(e.CacheKey, wi) - } - w = wi.(*weightInfo) - } else { - w = wb.calcWeightInfo(e) - } - - if w.weightSum == 0 { - return new(DummyPicker) - } - - if w.balance { - picker := randomPickerPool.Get().(*randomPicker) - picker.immutableInstances = w.instances - picker.firstIndex = -1 - return picker - } - picker := weightedPickerPool.Get().(*weightedPicker) - picker.immutableEntries = w.entries - picker.weightSum = w.weightSum - picker.immutableInstances = w.instances - picker.firstIndex = -1 - return picker -} - -func (wb *weightedBalancer) calcWeightInfo(e discovery.Result) *weightInfo { - w := &weightInfo{ - balance: true, - instances: make([]discovery.Instance, len(e.Instances)), - entries: make([]entry, len(e.Instances)), - weightSum: 0, - } - - var cnt int - for idx := range e.Instances { - weight := e.Instances[idx].Weight() - if weight > 0 { - w.entries[cnt] = weight - w.instances[cnt] = e.Instances[idx] - if cnt > 0 && w.entries[cnt-1] != weight { - w.balance = false - } - w.weightSum += weight - cnt++ - } else { - klog.Warnf("KITEX: invalid weight, weight=%d instance=%s", weight, e.Instances[idx].Address()) - } - } - w.instances = w.instances[:cnt] - w.entries = w.entries[:cnt] - return w -} - -func (wb *weightedBalancer) calcWeightInfoExcludeInst(e discovery.Result, addr string) *weightInfo { - if len(e.Instances) <= 1 || addr == "" { - return wb.calcWeightInfo(e) - } - w := &weightInfo{ - balance: true, - entries: make([]entry, len(e.Instances)), - instances: make([]discovery.Instance, len(e.Instances)), - weightSum: 0, - } - var cnt int - for idx := range e.Instances { - weight := e.Instances[idx].Weight() - if weight > 0 { - if e.Instances[idx].Address().String() == addr { - continue - } - w.entries[cnt] = weight - w.instances[cnt] = e.Instances[idx] - w.weightSum += weight - if cnt > 0 && w.entries[cnt-1] != weight { - w.balance = false - } - cnt++ - } else { - klog.Warnf("KITEX: invalid weight: weight=%d instance=%s", weight, e.Instances[idx].Address()) - } - } - w.instances = w.instances[:cnt] - w.entries = w.entries[:cnt] - return w -} - -// Rebalance implements the Rebalancer interface. -func (wb *weightedBalancer) Rebalance(change discovery.Change) { - if !change.Result.Cacheable { - return - } - wb.cachedWeightInfo.Store(change.Result.CacheKey, wb.calcWeightInfo(change.Result)) -} - -// Delete implements the Rebalancer interface. -func (wb *weightedBalancer) Delete(change discovery.Change) { - if !change.Result.Cacheable { - return - } - wb.cachedWeightInfo.Delete(change.Result.CacheKey) -} - -func (wb *weightedBalancer) Name() string { - return "weight_random" + idx := fastrand.Intn(len(rp.instances)) + return rp.instances[idx] } diff --git a/pkg/loadbalance/weighted_random_test.go b/pkg/loadbalance/weighted_random_test.go deleted file mode 100644 index ca78ea07a3..0000000000 --- a/pkg/loadbalance/weighted_random_test.go +++ /dev/null @@ -1,278 +0,0 @@ -/* - * Copyright 2021 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package loadbalance - -import ( - "context" - "fmt" - "math" - "strconv" - "testing" - - "github.com/cloudwego/kitex/internal" - "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/discovery" -) - -func TestWeightedBalancer_GetPicker(t *testing.T) { - balancer := NewWeightedBalancer() - // nil - picker := balancer.GetPicker(discovery.Result{}) - test.Assert(t, picker != nil) - dp, ok := picker.(*DummyPicker) - test.Assert(t, ok && dp != nil) - - // invalid - picker = balancer.GetPicker(discovery.Result{ - Instances: []discovery.Instance{ - discovery.NewInstance("tcp", "addr1", -10, nil), - discovery.NewInstance("tcp", "addr2", -20, nil), - }, - }) - test.Assert(t, picker != nil) - dp, ok = picker.(*DummyPicker) - test.Assert(t, ok && dp != nil) - - // one instance - insList := []discovery.Instance{ - discovery.NewInstance("tcp", "addr1", 10, nil), - } - picker = balancer.GetPicker(discovery.Result{ - Instances: insList, - }) - test.Assert(t, picker != nil) - rp, ok := picker.(*randomPicker) - test.Assert(t, ok && rp != nil) - - // multi instances - insList = []discovery.Instance{ - discovery.NewInstance("tcp", "addr1", 10, nil), - discovery.NewInstance("tcp", "addr2", 20, nil), - discovery.NewInstance("tcp", "addr3", 30, nil), - } - picker = balancer.GetPicker(discovery.Result{ - Instances: insList, - }) - test.Assert(t, picker != nil) - - // balanced instances - insList = []discovery.Instance{ - discovery.NewInstance("tcp", "addr1", 10, nil), - discovery.NewInstance("tcp", "addr2", 10, nil), - discovery.NewInstance("tcp", "addr3", 10, nil), - } - picker = balancer.GetPicker(discovery.Result{ - Instances: insList, - }) - test.Assert(t, picker != nil) - rp, ok = picker.(*randomPicker) - test.Assert(t, ok && rp != nil) - test.Assert(t, balancer.Name() == "weight_random") -} - -func TestWeightedPicker_Next(t *testing.T) { - balancer := NewWeightedBalancer() - ctx := context.Background() - // nil - picker := balancer.GetPicker(discovery.Result{}) - ins := picker.Next(ctx, nil) - test.Assert(t, ins == nil) - - // empty instance - picker = balancer.GetPicker(discovery.Result{ - Instances: make([]discovery.Instance, 0), - }) - ins = picker.Next(ctx, nil) - test.Assert(t, ins == nil) - - // one instance - insList := []discovery.Instance{ - discovery.NewInstance("tcp", "addr1", 10, nil), - } - for i := 0; i < 100; i++ { - picker := balancer.GetPicker(discovery.Result{ - Instances: insList, - }) - ins := picker.Next(ctx, nil) - test.Assert(t, ins.Weight() == 10) - } - - // multi instances, weightSum > 0 - insList = []discovery.Instance{ - discovery.NewInstance("tcp", "addr1", 10, nil), - discovery.NewInstance("tcp", "addr2", 20, nil), - discovery.NewInstance("tcp", "addr3", 50, nil), - discovery.NewInstance("tcp", "addr4", 100, nil), - discovery.NewInstance("tcp", "addr5", 200, nil), - discovery.NewInstance("tcp", "addr6", 500, nil), - } - - var weightSum int - for _, ins := range insList { - weight := ins.Weight() - weightSum += weight - } - - n := 10000000 - pickedStat := map[int]int{} - for i := 0; i < n; i++ { - picker := balancer.GetPicker(discovery.Result{ - Instances: insList, - }) - ins := picker.Next(ctx, nil) - weight := ins.Weight() - if pickedCnt, ok := pickedStat[weight]; ok { - pickedStat[weight] = pickedCnt + 1 - } else { - pickedStat[weight] = 1 - } - } - - for _, ins := range insList { - weight := ins.Weight() - expect := float64(weight) / float64(weightSum) * float64(n) - actual := float64(pickedStat[weight]) - delta := math.Abs(expect - actual) - test.Assertf(t, delta/expect < 0.01, "delta(%f)/expect(%f) = %f", delta, expect, delta/expect) - } - - // weightSum = 0 - insList = []discovery.Instance{ - discovery.NewInstance("tcp", "addr1", 10, nil), - discovery.NewInstance("tcp", "addr2", -10, nil), - } - picker = balancer.GetPicker(discovery.Result{ - Instances: insList, - }) - test.Assert(t, picker.Next(ctx, nil) != nil) - test.Assert(t, picker.Next(ctx, nil) == nil) - - // weightSum < 0 - insList = []discovery.Instance{ - discovery.NewInstance("tcp", "addr1", 10, nil), - discovery.NewInstance("tcp", "addr2", -20, nil), - } - picker = balancer.GetPicker(discovery.Result{ - Instances: insList, - }) - test.Assert(t, picker.Next(ctx, nil) != nil) - test.Assert(t, picker.Next(ctx, nil) == nil) -} - -func TestWeightedPicker_NoMoreInstance(t *testing.T) { - balancer := NewWeightedBalancer() - ctx := context.Background() - - // multi instances, weightSum > 0 - insList := []discovery.Instance{ - discovery.NewInstance("tcp", "addr1", 10, nil), - discovery.NewInstance("tcp", "addr2", 20, nil), - discovery.NewInstance("tcp", "addr3", 50, nil), - discovery.NewInstance("tcp", "addr4", 100, nil), - discovery.NewInstance("tcp", "addr5", 200, nil), - discovery.NewInstance("tcp", "addr6", 500, nil), - } - - picker := balancer.GetPicker(discovery.Result{ - Instances: insList, - }) - - for i := 0; i < len(insList); i++ { - ins := picker.Next(ctx, nil) - test.Assert(t, ins != nil) - } - - ins := picker.Next(ctx, nil) - test.Assert(t, ins == nil) -} - -func makeNinstances(n int) (res []discovery.Instance) { - for i := 0; i < n; i++ { - res = append(res, discovery.NewInstance("tcp", "addr"+strconv.Itoa(i), 10, nil)) - } - return -} - -func makeNWeightedInstances(n int) (res []discovery.Instance) { - for i := 0; i < n; i++ { - res = append(res, discovery.NewInstance("tcp", "addr"+strconv.Itoa(i), i, nil)) - } - return -} - -func BenchmarkNewWeightedPicker(bb *testing.B) { - n := 10 - balancer := NewWeightedBalancer() - ctx := context.Background() - - for i := 0; i < 4; i++ { - bb.Run(fmt.Sprintf("%dins", n), func(b *testing.B) { - inss := makeNinstances(n) - e := discovery.Result{ - Cacheable: true, - CacheKey: "test", - Instances: inss, - } - picker := balancer.GetPicker(e) - picker.Next(ctx, nil) - picker.(internal.Reusable).Recycle() - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - picker := balancer.GetPicker(e) - picker.Next(ctx, nil) - if r, ok := picker.(internal.Reusable); ok { - r.Recycle() - } - } - }) - n *= 10 - } -} - -func BenchmarkGetPicker(b *testing.B) { - insList := genInstList() - balancer := &weightedBalancer{} - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - balancer.calcWeightInfo(discovery.Result{ - Instances: insList, - }) - } -} - -func BenchmarkGetPickerExcludeInst(b *testing.B) { - insList := genInstList() - balancer := &weightedBalancer{} - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - balancer.calcWeightInfoExcludeInst(discovery.Result{ - Instances: insList, - }, "addr5") - } -} - -func genInstList() []discovery.Instance { - n := 1000 - insList := make([]discovery.Instance, 0, n) - for i := 0; i < n; i++ { - insList = append(insList, discovery.NewInstance("tcp", "addr"+strconv.Itoa(i), 10, nil)) - } - return insList -} diff --git a/pkg/loadbalance/weighted_round_robin.go b/pkg/loadbalance/weighted_round_robin.go new file mode 100644 index 0000000000..25fa48c874 --- /dev/null +++ b/pkg/loadbalance/weighted_round_robin.go @@ -0,0 +1,153 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "context" + "sync" + + "github.com/bytedance/gopkg/lang/fastrand" + + "github.com/cloudwego/kitex/pkg/discovery" +) + +var ( + _ Picker = &WeightedRoundRobinPicker{} + _ Picker = &RoundRobinPicker{} +) + +const wrrVNodesBatchSize = 500 // it will calculate wrrVNodesBatchSize vnodes when insufficient + +type wrrNode struct { + discovery.Instance + current int +} + +func newWeightedRoundRobinPicker(instances []discovery.Instance) Picker { + wrrp := new(WeightedRoundRobinPicker) + wrrp.iterator = newRound() + + // shuffle nodes + wrrp.size = uint64(len(instances)) + wrrp.nodes = make([]*wrrNode, wrrp.size) + offset := fastrand.Uint64n(wrrp.size) + for idx := uint64(0); idx < wrrp.size; idx++ { + wrrp.nodes[idx] = &wrrNode{ + Instance: instances[(idx+offset)%wrrp.size], + current: 0, + } + } + + // init vnodes + totalWeight := 0 + for _, node := range instances { + totalWeight += node.Weight() + } + wrrp.vcapacity = uint64(totalWeight) + wrrp.vnodes = make([]discovery.Instance, wrrp.vcapacity) + wrrp.buildVirtualWrrNodes(wrrVNodesBatchSize) + return wrrp +} + +// WeightedRoundRobinPicker implement smooth weighted round-robin algorithm. +// Refer from https://github.com/phusion/nginx/commit/27e94984486058d73157038f7950a0a36ecc6e35 +type WeightedRoundRobinPicker struct { + nodes []*wrrNode + size uint64 + + iterator *round + vsize uint64 + vcapacity uint64 + vnodes []discovery.Instance + vlock sync.Mutex +} + +// Next implements the Picker interface. +func (wp *WeightedRoundRobinPicker) Next(ctx context.Context, request interface{}) (ins discovery.Instance) { + idx := wp.iterator.Next() % wp.vcapacity + // fast path + if wp.vnodes[idx] != nil { + return wp.vnodes[idx] + } + + // slow path + wp.vlock.Lock() + defer wp.vlock.Unlock() + if wp.vnodes[idx] != nil { // other goroutine has filled the vnodes + return wp.vnodes[idx] + } + vtarget := wp.vsize + wrrVNodesBatchSize + if idx > vtarget { + vtarget = idx + } + wp.buildVirtualWrrNodes(vtarget) + return wp.vnodes[idx] +} + +func (wp *WeightedRoundRobinPicker) buildVirtualWrrNodes(vtarget uint64) { + if vtarget > wp.vcapacity { + vtarget = wp.vcapacity + } + for i := wp.vsize; i < vtarget; i++ { + wp.vnodes[i] = nextWrrNode(wp.nodes).Instance + } + wp.vsize = vtarget +} + +func nextWrrNode(nodes []*wrrNode) (selected *wrrNode) { + maxCurrent := 0 + totalWeight := 0 + for _, node := range nodes { + node.current += node.Weight() + totalWeight += node.Weight() + if selected == nil || node.current > maxCurrent { + selected = node + maxCurrent = node.current + } + } + if selected == nil { + return nil + } + selected.current -= totalWeight + return selected +} + +// RoundRobinPicker . +type RoundRobinPicker struct { + size uint64 + instances []discovery.Instance + iterator *round +} + +func newRoundRobinPicker(instances []discovery.Instance) Picker { + size := uint64(len(instances)) + return &RoundRobinPicker{ + size: size, + instances: instances, + iterator: newRound(), + } +} + +// Next implements the Picker interface. +func (rp *RoundRobinPicker) Next(ctx context.Context, request interface{}) (ins discovery.Instance) { + if rp.size == 0 { + return nil + } + idx := rp.iterator.Next() % rp.size + ins = rp.instances[idx] + return ins +} diff --git a/pkg/loadbalance/weighted_round_robin_test.go b/pkg/loadbalance/weighted_round_robin_test.go new file mode 100644 index 0000000000..fe56f29bee --- /dev/null +++ b/pkg/loadbalance/weighted_round_robin_test.go @@ -0,0 +1,118 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "context" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/discovery" +) + +func TestRoundRobinPicker(t *testing.T) { + ctx := context.Background() + insList := []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", 10, nil), + discovery.NewInstance("tcp", "addr2", 10, nil), + discovery.NewInstance("tcp", "addr3", 10, nil), + } + roundRobinPickers := []Picker{ + newRoundRobinPicker(insList), + newWeightedRoundRobinPicker(insList), + } + for _, picker := range roundRobinPickers { + accessMap := map[string]int{} + for i := 0; i < len(insList); i++ { + node := picker.Next(ctx, nil) + accessMap[node.Address().String()] += node.Weight() + test.Assert(t, len(accessMap) == i+1, accessMap) + } + test.Assert(t, len(accessMap) == len(insList), accessMap) + for i := 0; i < len(insList)*9; i++ { + node := picker.Next(ctx, nil) + accessMap[node.Address().String()] += node.Weight() + } + test.Assert(t, len(accessMap) == len(insList), accessMap) + for _, v := range accessMap { + test.Assert(t, v == 10*10, v) + } + } +} + +func TestWeightedRoundRobinPicker(t *testing.T) { + ctx := context.Background() + insList := []discovery.Instance{ + discovery.NewInstance("tcp", "addr1", 10, nil), + discovery.NewInstance("tcp", "addr2", 20, nil), + discovery.NewInstance("tcp", "addr3", 30, nil), + } + picker := newWeightedRoundRobinPicker(insList) + accessMap := map[string]int{} + round := len(insList) * 1000 + doubleAccess := 0 + var lastNode discovery.Instance + for i := 0; i < round; i++ { + node := picker.Next(ctx, nil) + accessMap[node.Address().String()] += 1 + if lastNode != nil && lastNode.Address() == node.Address() { + doubleAccess++ + } + lastNode = node + } + t.Logf("doubleAccess: %d", doubleAccess) + test.Assert(t, len(accessMap) == len(insList), accessMap) + test.Assert(t, accessMap["addr1"] == round/6, accessMap) + test.Assert(t, accessMap["addr2"] == round/6*2, accessMap) + test.Assert(t, accessMap["addr3"] == round/6*3, accessMap) +} + +func TestNextWrrNode(t *testing.T) { + wrrNodes := []*wrrNode{ + {Instance: discovery.NewInstance("tcp", "addr1", 5, nil)}, + {Instance: discovery.NewInstance("tcp", "addr2", 1, nil)}, + {Instance: discovery.NewInstance("tcp", "addr3", 1, nil)}, + } + + // round 1 + node := nextWrrNode(wrrNodes) + test.Assert(t, node.Address() == wrrNodes[0].Address()) + test.Assert(t, wrrNodes[0].current == -2, wrrNodes[0]) + test.Assert(t, wrrNodes[1].current == 1, wrrNodes[1]) + test.Assert(t, wrrNodes[2].current == 1, wrrNodes[2]) + + // round 2 + node = nextWrrNode(wrrNodes) + test.Assert(t, node.Address() == wrrNodes[0].Address()) + test.Assert(t, wrrNodes[0].current == -4, wrrNodes[0]) + test.Assert(t, wrrNodes[1].current == 2, wrrNodes[1]) + test.Assert(t, wrrNodes[2].current == 2, wrrNodes[2]) + + // round 3 + node = nextWrrNode(wrrNodes) + test.Assert(t, node.Address() == wrrNodes[1].Address()) + test.Assert(t, wrrNodes[0].current == 1, wrrNodes[0]) + test.Assert(t, wrrNodes[1].current == -4, wrrNodes[1]) + test.Assert(t, wrrNodes[2].current == 3, wrrNodes[2]) + + // round 4 + node = nextWrrNode(wrrNodes) + test.Assert(t, node.Address() == wrrNodes[0].Address()) + test.Assert(t, wrrNodes[0].current == -1, wrrNodes[0]) + test.Assert(t, wrrNodes[1].current == -3, wrrNodes[1]) + test.Assert(t, wrrNodes[2].current == 4, wrrNodes[2]) +} diff --git a/pkg/profiler/profiler.go b/pkg/profiler/profiler.go index 043358f524..d52683b297 100644 --- a/pkg/profiler/profiler.go +++ b/pkg/profiler/profiler.go @@ -24,7 +24,7 @@ import ( "runtime/pprof" "sort" "strings" - "sync/atomic" + "sync" "time" "github.com/google/pprof/profile" @@ -68,25 +68,30 @@ func NewProfiler(processor Processor, interval, window time.Duration) *profiler processor = LogProcessor } return &profiler{ - stateTrigger: make(chan struct{}), - processor: processor, - interval: interval, - window: window, + stateCond: sync.NewCond(&sync.Mutex{}), + processor: processor, + interval: interval, + window: window, } } var _ Profiler = (*profiler)(nil) const ( - stateRunning = 0 - statePaused = 1 - stateStopped = 2 + // state changes: + // running => pausing => paused => resuming => running + // => stopped + stateRunning = 0 + statePausing = 1 + statePaused = 2 + stateResuming = 3 + stateStopped = 4 ) type profiler struct { - data bytes.Buffer // protobuf - stateTrigger chan struct{} - state int32 + data bytes.Buffer // protobuf + state int + stateCond *sync.Cond // settings processor Processor interval time.Duration // sleep time between every profiling window @@ -130,58 +135,45 @@ func (p *profiler) Prepare(ctx context.Context) context.Context { return context.WithValue(ctx, profilerContextKey{}, newProfilerContext(p)) } -// Stop the profiler analyse loop -func (p *profiler) Stop() { - for !atomic.CompareAndSwapInt32(&p.state, stateRunning, stateStopped) { - // wait for profiler to running state - time.Sleep(time.Second) - } - p.stateTrigger <- struct{}{} +// State return current profiler state +func (p *profiler) State() (state int) { + p.stateCond.L.Lock() + state = p.state + p.stateCond.L.Unlock() + return state } -func (p *profiler) State() int32 { - return atomic.LoadInt32(&p.state) -} - -// Pause the profiler -func (p *profiler) Pause() { - if atomic.CompareAndSwapInt32(&p.state, stateRunning, statePaused) { - // stop first then trigger - p.stopProfile() - p.stateTrigger <- struct{}{} +// Stop the profiler +func (p *profiler) Stop() { + if p.State() == stateStopped { + return } + // stateRunning => stateStopped + p.stateChange(stateRunning, stateStopped) } -func (p *profiler) waitResumed() { +// Pause the profiler. +// The profiler has been paused when Pause() return +func (p *profiler) Pause() { if p.State() == statePaused { - for p.State() != stateRunning { - // prevent if resumed twice - <-p.stateTrigger - } + return } + // stateRunning => statePausing + p.stateChange(stateRunning, statePausing) + // => statePaused + p.stateWait(statePaused) } -// Resume the profiler +// Resume the profiler. +// The profiler has been resumed when Resume() return func (p *profiler) Resume() { - if atomic.CompareAndSwapInt32(&p.state, statePaused, stateRunning) { - p.stateTrigger <- struct{}{} - } -} - -func timeoutOrTrigger(timer *time.Timer, trigger chan struct{}) { - select { - case <-timer.C: - // clear trigger if it's also active - select { - case <-trigger: - default: - } - case <-trigger: - // stop timer when trigger active - if !timer.Stop() { - <-timer.C - } + if p.State() == stateRunning { + return } + // statePaused => stateResuming + p.stateChange(statePaused, stateResuming) + // => stateRunning + p.stateWait(stateRunning) } // Run start analyse the pprof data with interval and window settings @@ -189,18 +181,27 @@ func (p *profiler) Run(ctx context.Context) (err error) { var profiles []*TagsProfile timer := time.NewTimer(0) for { - // wait for an internal time to reduce the cost of profiling - if p.interval > 0 { - timer.Reset(p.interval) - timeoutOrTrigger(timer, p.stateTrigger) - } - switch p.State() { - case stateStopped: - return nil - case statePaused: + // check state + state := p.State() + switch state { + case stateRunning: // do nothing + case statePausing: // pause the loop + p.stateChange(statePausing, statePaused) // wake up Pause() klog.Info("KITEX: profiler paused") - p.waitResumed() + p.stateChange(stateResuming, stateRunning) // wake up Resume() klog.Info("KITEX: profiler resumed") + continue + case statePaused, stateResuming: // actually, no such case + continue + case stateStopped: // end the loop + klog.Info("KITEX: profiler stopped") + return nil + } + + // wait for an interval time to reduce the cost of profiling + if p.interval > 0 { + timer.Reset(p.interval) + <-timer.C } // start profiler @@ -215,17 +216,7 @@ func (p *profiler) Run(ctx context.Context) (err error) { // wait for a window time to collect pprof data if p.window > 0 { timer.Reset(p.window) - timeoutOrTrigger(timer, p.stateTrigger) - } - switch p.State() { - case stateStopped: - p.stopProfile() - return nil - case statePaused: - klog.Info("KITEX: profiler paused") - p.waitResumed() - klog.Info("KITEX: profiler resumed") - continue + <-timer.C } // stop profiler @@ -277,6 +268,24 @@ func (p *profiler) Lookup(ctx context.Context, key string) (string, bool) { return pprof.Label(ctx, key) } +func (p *profiler) stateChange(from, to int) { + p.stateCond.L.Lock() + for p.state != from { // wait state to from first + p.stateCond.Wait() + } + p.state = to + p.stateCond.L.Unlock() + p.stateCond.Broadcast() +} + +func (p *profiler) stateWait(to int) { + p.stateCond.L.Lock() + for p.state != to { + p.stateCond.Wait() + } + p.stateCond.L.Unlock() +} + func (p *profiler) startProfile() error { p.data.Reset() return pprof.StartCPUProfile(&p.data) diff --git a/pkg/profiler/profiler_test.go b/pkg/profiler/profiler_test.go index cb0cf2da25..5ca260283e 100644 --- a/pkg/profiler/profiler_test.go +++ b/pkg/profiler/profiler_test.go @@ -75,8 +75,8 @@ func TestProfiler(t *testing.T) { } func TestProfilerPaused(t *testing.T) { - interval := time.Millisecond * 100 - window := time.Millisecond * 100 + interval := time.Millisecond * 50 + window := time.Millisecond * 50 var count int32 p := NewProfiler(func(profiles []*TagsProfile) error { atomic.AddInt32(&count, 1) @@ -89,7 +89,7 @@ func TestProfilerPaused(t *testing.T) { test.Assert(t, err == nil, err) close(stopCh) }() - time.Sleep(interval / 2) + time.Sleep(interval) var data bytes.Buffer for i := 0; i < 5; i++ { @@ -101,7 +101,6 @@ func TestProfilerPaused(t *testing.T) { pprof.StopCPUProfile() p.Resume() p.Resume() // resume twice by mistake - time.Sleep(interval) } for atomic.LoadInt32(&count) > 5 { // wait for processor finished @@ -109,6 +108,7 @@ func TestProfilerPaused(t *testing.T) { } p.Stop() + p.Stop() // stop twice by mistake <-stopCh } diff --git a/pkg/remote/bound/transmeta_bound.go b/pkg/remote/bound/transmeta_bound.go index eebeb41b44..beb2ef9ee6 100644 --- a/pkg/remote/bound/transmeta_bound.go +++ b/pkg/remote/bound/transmeta_bound.go @@ -20,6 +20,8 @@ import ( "context" "net" + "github.com/bytedance/gopkg/cloud/metainfo" + "github.com/cloudwego/kitex/pkg/consts" "github.com/cloudwego/kitex/pkg/remote" ) @@ -62,6 +64,11 @@ func (h *transMetaHandler) OnMessage(ctx context.Context, args, result remote.Me if isServer && result.MessageType() != remote.Exception { // Pass through method name using ctx, the method name will be used as from method in the client. ctx = context.WithValue(ctx, consts.CtxKeyMethod, msg.RPCInfo().To().Method()) + // TransferForward converts transient values to transient-upstream values and filters out original transient-upstream values. + // It should be used before the context is passing from server to client. + // reference https://github.com/bytedance/gopkg/tree/main/cloud/metainfo + // Notice, it should be after ReadMeta(). + ctx = metainfo.TransferForward(ctx) } return ctx, nil } diff --git a/pkg/remote/bound/transmeta_bound_test.go b/pkg/remote/bound/transmeta_bound_test.go index 5a36a5780f..5eaf335870 100644 --- a/pkg/remote/bound/transmeta_bound_test.go +++ b/pkg/remote/bound/transmeta_bound_test.go @@ -21,6 +21,7 @@ import ( "errors" "testing" + "github.com/bytedance/gopkg/cloud/metainfo" "github.com/golang/mock/gomock" mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" @@ -30,6 +31,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote/trans/invoke" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" + "github.com/cloudwego/kitex/pkg/transmeta" ) // TestNewTransMetaHandler test NewTransMetaHandler function and assert the result not nil. @@ -204,6 +206,51 @@ func TestTransMetaHandlerOnMessage(t *testing.T) { test.Assert(t, err == nil) test.Assert(t, ctx != nil && args.RPCInfo().To().Method() == ctx.Value(consts.CtxKeyMethod)) }) + + t.Run("Test metainfo transient key", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + metaHandler1 := transmeta.MetainfoServerHandler + metaHandler2 := mocksremote.NewMockMetaHandler(ctrl) + + ink := rpcinfo.NewInvocation("", "mock") + to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{}, "") + ri := rpcinfo.NewRPCInfo(nil, to, ink, nil, nil) + + args := remote.NewMessage( + nil, nil, ri, remote.Call, remote.Server) + result := remote.NewMessage( + nil, nil, nil, remote.Reply, remote.Server) + ctx := context.Background() + + tk1, tv1 := "tk1", "tv1" + tk2, tv2 := "tk2", "tv2" + args.TransInfo().PutTransStrInfo(map[string]string{metainfo.PrefixTransient + tk1: tv1}) + metaHandler2.EXPECT().ReadMeta(gomock.Any(), args).DoAndReturn(func(ctx context.Context, msg remote.Message) (context.Context, error) { + ctx = metainfo.SetMetaInfoFromMap(ctx, map[string]string{metainfo.PrefixTransient + tk2: tv2}) + return ctx, nil + }).Times(1) + mhs := []remote.MetaHandler{ + metaHandler1, metaHandler2, + } + + handler := NewTransMetaHandler(mhs) + + test.Assert(t, handler != nil) + ctx, err := handler.OnMessage(ctx, args, result) + test.Assert(t, err == nil) + v, ok := metainfo.GetValue(ctx, tk1) + test.Assert(t, ok) + test.Assert(t, v == tv1) + v, ok = metainfo.GetValue(ctx, tk2) + test.Assert(t, ok) + test.Assert(t, v == tv2) + + kvs := make(map[string]string) + metainfo.SaveMetaInfoToMap(ctx, kvs) + test.Assert(t, len(kvs) == 0) + }) } // TestGetValidMsg test getValidMsg function with message of server side and client side. diff --git a/pkg/remote/codec/header_codec.go b/pkg/remote/codec/header_codec.go index 21be6f403b..97c24dcf90 100644 --- a/pkg/remote/codec/header_codec.go +++ b/pkg/remote/codec/header_codec.go @@ -196,7 +196,7 @@ func (t ttHeader) decode(ctx context.Context, message remote.Message, in remote. } if err := readKVInfo(hdIdx, headerInfo, message); err != nil { - return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader read kv info failed, %s", err.Error())) + return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader read kv info failed, %s, headerInfo=%#x", err.Error(), headerInfo)) } fillBasicInfoOfTTHeader(message) @@ -208,7 +208,25 @@ func writeKVInfo(writtenSize int, message remote.Message, out remote.ByteBuffer) writeSize = writtenSize tm := message.TransInfo() // str kv info - strKVSize := len(tm.TransStrInfo()) + strKVMap := tm.TransStrInfo() + strKVSize := len(strKVMap) + // write gdpr token into InfoIDACLToken + // supplementary doc: https://www.cloudwego.io/docs/kitex/reference/transport_protocol_ttheader/ + if gdprToken, ok := strKVMap[transmeta.GDPRToken]; ok { + strKVSize-- + // INFO ID TYPE(u8) + if err = WriteByte(byte(InfoIDACLToken), out); err != nil { + return writeSize, err + } + writeSize += 1 + + wLen, err := WriteString2BLen(gdprToken, out) + if err != nil { + return writeSize, err + } + writeSize += wLen + } + if strKVSize > 0 { // INFO ID TYPE(u8) + NUM HEADERS(u16) if err = WriteByte(byte(InfoIDKeyValue), out); err != nil { @@ -218,7 +236,10 @@ func writeKVInfo(writtenSize int, message remote.Message, out remote.ByteBuffer) return writeSize, err } writeSize += 3 - for key, val := range tm.TransStrInfo() { + for key, val := range strKVMap { + if key == transmeta.GDPRToken { + continue + } keyWLen, err := WriteString2BLen(key, out) if err != nil { return writeSize, err @@ -253,6 +274,7 @@ func writeKVInfo(writtenSize int, message remote.Message, out remote.ByteBuffer) writeSize = writeSize + 2 + valWLen } } + // padding = (4 - headerInfoSize%4) % 4 padding := (4 - writeSize%4) % 4 paddingBuf, err := out.Malloc(padding) @@ -294,8 +316,7 @@ func readKVInfo(idx int, buf []byte, message remote.Message) error { return err } case InfoIDACLToken: - err = skipACLToken(&idx, buf) - if err != nil { + if err := readACLToken(&idx, buf, strInfo); err != nil { return err } default: @@ -355,13 +376,14 @@ func readStrKVInfo(idx *int, buf []byte, info map[string]string) (has bool, err return true, nil } -// skipACLToken SDK don't need acl token, just skip it -func skipACLToken(idx *int, buf []byte) error { - _, n, err := ReadString2BLen(buf, *idx) +// readACLToken reads acl token +func readACLToken(idx *int, buf []byte, info map[string]string) error { + val, n, err := ReadString2BLen(buf, *idx) *idx += n if err != nil { return fmt.Errorf("error reading acl token: %s", err.Error()) } + info[transmeta.GDPRToken] = val return nil } diff --git a/pkg/remote/codec/header_codec_test.go b/pkg/remote/codec/header_codec_test.go index 7bfce584c6..10d0903453 100644 --- a/pkg/remote/codec/header_codec_test.go +++ b/pkg/remote/codec/header_codec_test.go @@ -22,6 +22,8 @@ import ( "net" "testing" + "github.com/bytedance/gopkg/cloud/metainfo" + "github.com/cloudwego/kitex/internal/mocks" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/discovery" @@ -30,6 +32,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote/transmeta" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" + tm "github.com/cloudwego/kitex/pkg/transmeta" "github.com/cloudwego/kitex/transport" ) @@ -88,6 +91,73 @@ func TestTTHeaderCodecWithTransInfo(t *testing.T) { test.Assert(t, flag == uint16(HeaderFlagSupportOutOfOrder)) } +func TestTTHeaderCodecWithTransInfoWithGDPRToken(t *testing.T) { + ctx := context.Background() + intKVInfo := prepareIntKVInfo() + strKVInfo := prepareStrKVInfoWithGDPRToken() + sendMsg := initClientSendMsg(transport.TTHeader) + sendMsg.TransInfo().PutTransIntInfo(intKVInfo) + sendMsg.TransInfo().PutTransStrInfo(strKVInfo) + sendMsg.Tags()[HeaderFlagsKey] = HeaderFlagSupportOutOfOrder + + // encode + out := remote.NewWriterBuffer(256) + totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, out) + binary.BigEndian.PutUint32(totalLenField, uint32(out.MallocLen()-Size32+mockPayloadLen)) + test.Assert(t, err == nil, err) + + // decode + recvMsg := initServerRecvMsg() + buf, err := out.Bytes() + test.Assert(t, err == nil, err) + in := remote.NewReaderBuffer(buf) + err = ttHeaderCodec.decode(ctx, recvMsg, in) + test.Assert(t, err == nil, err) + test.Assert(t, recvMsg.PayloadLen() == mockPayloadLen, recvMsg.PayloadLen()) + + intKVInfoRecv := recvMsg.TransInfo().TransIntInfo() + strKVInfoRecv := recvMsg.TransInfo().TransStrInfo() + test.DeepEqual(t, intKVInfoRecv, intKVInfo) + test.DeepEqual(t, strKVInfoRecv, strKVInfo) + flag := recvMsg.Tags()[HeaderFlagsKey] + test.Assert(t, flag != nil) + test.Assert(t, flag == uint16(HeaderFlagSupportOutOfOrder)) +} + +func TestTTHeaderCodecWithTransInfoFromMetaInfoGDPRToken(t *testing.T) { + ctx := context.Background() + intKVInfo := prepareIntKVInfo() + ctx = metainfo.WithValue(ctx, "gdpr-token", "test token") + sendMsg := initClientSendMsg(transport.TTHeader) + sendMsg.TransInfo().PutTransIntInfo(intKVInfo) + ctx, err := tm.MetainfoClientHandler.WriteMeta(ctx, sendMsg) + test.Assert(t, err == nil) + sendMsg.Tags()[HeaderFlagsKey] = HeaderFlagSupportOutOfOrder + + // encode + out := remote.NewWriterBuffer(256) + totalLenField, err := ttHeaderCodec.encode(ctx, sendMsg, out) + binary.BigEndian.PutUint32(totalLenField, uint32(out.MallocLen()-Size32+mockPayloadLen)) + test.Assert(t, err == nil, err) + + // decode + recvMsg := initServerRecvMsg() + buf, err := out.Bytes() + test.Assert(t, err == nil, err) + in := remote.NewReaderBuffer(buf) + err = ttHeaderCodec.decode(ctx, recvMsg, in) + test.Assert(t, err == nil, err) + test.Assert(t, recvMsg.PayloadLen() == mockPayloadLen, recvMsg.PayloadLen()) + + intKVInfoRecv := recvMsg.TransInfo().TransIntInfo() + strKVInfoRecv := recvMsg.TransInfo().TransStrInfo() + test.DeepEqual(t, intKVInfoRecv, intKVInfo) + test.DeepEqual(t, strKVInfoRecv, map[string]string{transmeta.GDPRToken: "test token"}) + flag := recvMsg.Tags()[HeaderFlagsKey] + test.Assert(t, flag != nil) + test.Assert(t, flag == uint16(HeaderFlagSupportOutOfOrder)) +} + func TestFillBasicInfoOfTTHeader(t *testing.T) { ctx := context.Background() mockAddr := "mock address" @@ -301,6 +371,14 @@ func prepareStrKVInfo() map[string]string { return kvInfo } +func prepareStrKVInfoWithGDPRToken() map[string]string { + kvInfo := map[string]string{ + transmeta.GDPRToken: "mockToken", + transmeta.HeaderTransRemoteAddr: "mockRemoteAddr", + } + return kvInfo +} + // // TODO 是否提供buf.writeInt8/16/32方法,否则得先计算,然后malloc,最后write,待确认频繁malloc是否有影响 // 暂时不删除,测试一次malloc, 和多次malloc差异 // diff --git a/pkg/remote/codec/protobuf/grpc.go b/pkg/remote/codec/protobuf/grpc.go index a7f03b4f01..c8445a9d63 100644 --- a/pkg/remote/codec/protobuf/grpc.go +++ b/pkg/remote/codec/protobuf/grpc.go @@ -28,6 +28,8 @@ import ( "github.com/cloudwego/kitex/pkg/remote" ) +const dataFrameHeaderLen = 5 + // gogoproto generate type marshaler interface { MarshalTo(data []byte) (n int, err error) @@ -46,6 +48,12 @@ func NewGRPCCodec() remote.Codec { return new(grpcCodec) } +func mallocWithFirstByteZeroed(size int) []byte { + data := mcache.Malloc(size) + data[0] = 0 // compressed flag = false + return data +} + func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) { writer, ok := out.(remote.FrameWrite) if !ok { @@ -57,18 +65,18 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo case fastpb.Writer: // TODO: reuse data buffer when we can free it safely size := t.Size() - data = mcache.Malloc(size + 5) - t.FastWrite(data[5:]) - binary.BigEndian.PutUint32(data[1:5], uint32(size)) + data = mallocWithFirstByteZeroed(size + dataFrameHeaderLen) + t.FastWrite(data[dataFrameHeaderLen:]) + binary.BigEndian.PutUint32(data[1:dataFrameHeaderLen], uint32(size)) return writer.WriteData(data) case marshaler: // TODO: reuse data buffer when we can free it safely size := t.Size() - data = mcache.Malloc(size + 5) - if _, err = t.MarshalTo(data[5:]); err != nil { + data = mallocWithFirstByteZeroed(size + dataFrameHeaderLen) + if _, err = t.MarshalTo(data[dataFrameHeaderLen:]); err != nil { return err } - binary.BigEndian.PutUint32(data[1:5], uint32(size)) + binary.BigEndian.PutUint32(data[1:dataFrameHeaderLen], uint32(size)) return writer.WriteData(data) case protobufV2MsgCodec: data, err = t.XXX_Marshal(nil, true) @@ -100,10 +108,18 @@ func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remot } message.SetPayloadLen(dLen) data := message.Data() + if t, ok := data.(fastpb.Reader); ok { + if len(d) == 0 { + // if all fields of a struct is default value, data will be nil + // In the implementation of fastpb, if data is nil, then fastpb will skip creating this struct, as a result user will get a nil pointer which is not expected. + // So, when data is nil, use default protobuf unmarshal method to decode the struct. + // todo: fix fastpb + } else { + _, err = fastpb.ReadMessage(d, fastpb.SkipTypeCheck, t) + return err + } + } switch t := data.(type) { - case fastpb.Reader: - _, err = fastpb.ReadMessage(d, fastpb.SkipTypeCheck, t) - return err case protobufV2MsgCodec: return t.XXX_Unmarshal(d) case proto.Message: diff --git a/pkg/remote/codec/protobuf/grpc_test.go b/pkg/remote/codec/protobuf/grpc_test.go new file mode 100644 index 0000000000..1291e5b9fe --- /dev/null +++ b/pkg/remote/codec/protobuf/grpc_test.go @@ -0,0 +1,70 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package protobuf + +import ( + "testing" + + "github.com/bytedance/gopkg/lang/mcache" +) + +func Test_mallocWithFirstByteZeroed(t *testing.T) { + type args struct { + size int + data []byte + } + tests := []struct { + name string + args args + want byte + }{ + { + name: "test_with_no_data", + args: args{ + size: 4, + data: nil, + }, + want: 0, + }, + { + name: "test_with_zeroed_data", + args: args{ + size: 8, + data: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + }, + want: 0, + }, + { + name: "test_with_non_zeroed_data", + args: args{ + size: 16, + data: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + }, + want: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.args.data != nil { + buf := mcache.Malloc(tt.args.size) + copy(buf, tt.args.data) + mcache.Free(buf) + } + if got := mallocWithFirstByteZeroed(tt.args.size); got[0] != tt.want { + t.Errorf("mallocWithFirstByteZeroed() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/remote/codec/protobuf/protobuf.go b/pkg/remote/codec/protobuf/protobuf.go index 1ec4a6dd63..817af58c86 100644 --- a/pkg/remote/codec/protobuf/protobuf.go +++ b/pkg/remote/codec/protobuf/protobuf.go @@ -155,11 +155,18 @@ func (c protobufCodec) Unmarshal(ctx context.Context, message remote.Message, in data := message.Data() // fast read if msg, ok := data.(fastpb.Reader); ok { - _, err := fastpb.ReadMessage(actualMsgBuf, fastpb.SkipTypeCheck, msg) - if err != nil { - return remote.NewTransErrorWithMsg(remote.ProtocolError, err.Error()) + if len(actualMsgBuf) == 0 { + // if all fields of a struct is default value, actualMsgLen will be zero and actualMsgBuf will be nil + // In the implementation of fastpb, if actualMsgBuf is nil, then fastpb will skip creating this struct, as a result user will get a nil pointer which is not expected. + // So, when actualMsgBuf is nil, use default protobuf unmarshal method to decode the struct. + // todo: fix fastpb + } else { + _, err := fastpb.ReadMessage(actualMsgBuf, fastpb.SkipTypeCheck, msg) + if err != nil { + return remote.NewTransErrorWithMsg(remote.ProtocolError, err.Error()) + } + return nil } - return nil } msg, ok := data.(protobufMsgCodec) if !ok { diff --git a/pkg/remote/codec/thrift/thrift_frugal_amd64.go b/pkg/remote/codec/thrift/thrift_frugal_amd64.go index b46f52aa22..1a50f35e19 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_amd64.go +++ b/pkg/remote/codec/thrift/thrift_frugal_amd64.go @@ -1,5 +1,5 @@ -//go:build amd64 && !windows && go1.15 -// +build amd64,!windows,go1.15 +//go:build amd64 && !windows && go1.16 +// +build amd64,!windows,go1.16 /* * Copyright 2021 CloudWeGo Authors diff --git a/pkg/remote/codec/thrift/thrift_frugal_amd64_test.go b/pkg/remote/codec/thrift/thrift_frugal_amd64_test.go index 40032aacb0..4661d4dc65 100644 --- a/pkg/remote/codec/thrift/thrift_frugal_amd64_test.go +++ b/pkg/remote/codec/thrift/thrift_frugal_amd64_test.go @@ -1,5 +1,5 @@ -//go:build amd64 && !windows && go1.15 -// +build amd64,!windows,go1.15 +//go:build amd64 && !windows && go1.16 +// +build amd64,!windows,go1.16 /* * Copyright 2021 CloudWeGo Authors diff --git a/pkg/remote/codec/thrift/thrift_others.go b/pkg/remote/codec/thrift/thrift_others.go index 01dd9793d7..839ac126f5 100644 --- a/pkg/remote/codec/thrift/thrift_others.go +++ b/pkg/remote/codec/thrift/thrift_others.go @@ -1,5 +1,5 @@ -//go:build !amd64 || windows || !go1.15 -// +build !amd64 windows !go1.15 +//go:build !amd64 || windows || !go1.16 +// +build !amd64 windows !go1.16 /* * Copyright 2021 CloudWeGo Authors diff --git a/pkg/remote/connpool/long_pool.go b/pkg/remote/connpool/long_pool.go index b77a82d3c8..cdc36021fa 100644 --- a/pkg/remote/connpool/long_pool.go +++ b/pkg/remote/connpool/long_pool.go @@ -19,8 +19,10 @@ package connpool import ( "context" + "fmt" "net" "sync" + "sync/atomic" "time" "github.com/cloudwego/kitex/pkg/connpool" @@ -32,8 +34,28 @@ import ( var ( _ net.Conn = &longConn{} _ remote.LongConnPool = &LongPool{} + + // global shared tickers for different LongPool + sharedTickers sync.Map +) + +const ( + configDumpKey = "idle_config" ) +func getSharedTicker(p *LongPool, refreshInterval time.Duration) *utils.SharedTicker { + sti, ok := sharedTickers.Load(refreshInterval) + if ok { + st := sti.(*utils.SharedTicker) + st.Add(p) + return st + } + sti, _ = sharedTickers.LoadOrStore(refreshInterval, utils.NewSharedTicker(refreshInterval)) + st := sti.(*utils.SharedTicker) + st.Add(p) + return st +} + // netAddr implements the net.Addr interface and comparability. type netAddr struct { network string @@ -283,20 +305,22 @@ func NewLongPool(serviceName string, idlConfig connpool.IdleConfig) *LongPool { idlConfig.MaxIdleTimeout, limit) }, - closeCh: make(chan struct{}), + idleConfig: idlConfig, } - - go lp.Evict(idlConfig.MaxIdleTimeout) + // add this long pool into the sharedTicker + lp.sharedTicker = getSharedTicker(lp, idlConfig.MaxIdleTimeout) return lp } // LongPool manages a pool of long connections. type LongPool struct { - reporter Reporter - peerMap sync.Map - newPeer func(net.Addr) *peer - globalIdle *utils.MaxCounter - closeCh chan struct{} + reporter Reporter + peerMap sync.Map + newPeer func(net.Addr) *peer + globalIdle *utils.MaxCounter + idleConfig connpool.IdleConfig + sharedTicker *utils.SharedTicker + closed int32 // active: 0, closed: 1 } // Get pick or generate a net.Conn and return @@ -345,6 +369,7 @@ func (lp *LongPool) Clean(network, address string) { // Dump is used to dump current long pool info when needed, like debug query. func (lp *LongPool) Dump() interface{} { m := make(map[string]interface{}) + m[configDumpKey] = lp.idleConfig lp.peerMap.Range(func(key, value interface{}) bool { t := value.(*peer).pool.Dump() m[key.(netAddr).String()] = t @@ -355,17 +380,18 @@ func (lp *LongPool) Dump() interface{} { // Close releases all peers in the pool, it is executed when client is closed. func (lp *LongPool) Close() error { - select { - case <-lp.closeCh: - default: - close(lp.closeCh) + if !atomic.CompareAndSwapInt32(&lp.closed, 0, 1) { + return fmt.Errorf("long pool is already closed") } + // close all peers lp.peerMap.Range(func(addr, value interface{}) bool { lp.peerMap.Delete(addr) v := value.(*peer) v.Close() return true }) + // remove from the shared ticker + lp.sharedTicker.Delete(lp) return nil } @@ -381,23 +407,22 @@ func (lp *LongPool) WarmUp(eh warmup.ErrorHandling, wuo *warmup.PoolOption, co r } // Evict cleanups the idle connections in peers. -func (lp *LongPool) Evict(frequency time.Duration) { - t := time.NewTicker(frequency) - defer t.Stop() - for { - select { - case <-t.C: - lp.peerMap.Range(func(key, value interface{}) bool { - p := value.(*peer) - p.Evict() - return true - }) - case <-lp.closeCh: - return - } +func (lp *LongPool) Evict() { + if atomic.LoadInt32(&lp.closed) == 0 { + // Evict idle connections + lp.peerMap.Range(func(key, value interface{}) bool { + p := value.(*peer) + p.Evict() + return true + }) } } +// Tick implements the interface utils.TickerTask. +func (lp *LongPool) Tick() { + lp.Evict() +} + // getPeer gets a peer from the pool based on the addr, or create a new one if not exist. func (lp *LongPool) getPeer(addr netAddr) *peer { p, ok := lp.peerMap.Load(addr) diff --git a/pkg/remote/connpool/long_pool_test.go b/pkg/remote/connpool/long_pool_test.go index 496f7c88c8..e3f239613d 100644 --- a/pkg/remote/connpool/long_pool_test.go +++ b/pkg/remote/connpool/long_pool_test.go @@ -745,6 +745,28 @@ func TestConnPoolClose(t *testing.T) { test.Assert(t, connCount == 0) } +func TestClosePoolAndSharedTicker(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var ( + poolNum = 10 + idleTimeoutUnit = 111 * time.Millisecond + pools = make([]*LongPool, poolNum) + ) + // add new pool with different idleTimeout, increasing the number of shared ticker + for i := 0; i < poolNum; i++ { + pools[i] = newLongPoolForTest(0, 2, 3, time.Duration(i+1)*idleTimeoutUnit) + } + // close + for i := 0; i < poolNum; i++ { + pools[i].Close() + // should be removed from shardTickers + _, ok := sharedTickers.Load(pools[i]) + test.Assert(t, !ok) + } +} + func TestLongConnPoolPutUnknownConnection(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/pkg/remote/trans/default_server_handler.go b/pkg/remote/trans/default_server_handler.go index b0f341ef1c..6150110f7d 100644 --- a/pkg/remote/trans/default_server_handler.go +++ b/pkg/remote/trans/default_server_handler.go @@ -19,6 +19,7 @@ package trans import ( "context" "errors" + "fmt" "net" "runtime/debug" @@ -110,27 +111,30 @@ func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, recvMsg remot } // OnRead implements the remote.ServerTransHandler interface. -func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { +// The connection should be closed after returning error. +func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) { ri := rpcinfo.GetRPCInfo(ctx) t.ext.SetReadTimeout(ctx, conn, ri.Config(), remote.Server) - var err error - var closeConn bool var recvMsg remote.Message var sendMsg remote.Message + closeConnOutsideIfErr := true defer func() { panicErr := recover() + var wrapErr error if panicErr != nil { - closeConn = true + stack := string(debug.Stack()) if conn != nil { ri := rpcinfo.GetRPCInfo(ctx) rService, rAddr := getRemoteInfo(ri, conn) - klog.CtxErrorf(ctx, "KITEX: panic happened, close conn, remoteAddress=%s, remoteService=%s, error=%v\nstack=%s", rAddr, rService, panicErr, string(debug.Stack())) + klog.CtxErrorf(ctx, "KITEX: panic happened, remoteAddress=%s, remoteService=%s, error=%v\nstack=%s", rAddr, rService, panicErr, stack) } else { - klog.CtxErrorf(ctx, "KITEX: panic happened, error=%v\nstack=%s", panicErr, string(debug.Stack())) + klog.CtxErrorf(ctx, "KITEX: panic happened, error=%v\nstack=%s", panicErr, stack) + } + if err != nil { + wrapErr = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[happened in OnRead] %s, last error=%s", panicErr, err.Error()), stack) + } else { + wrapErr = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[happened in OnRead] %s", panicErr), stack) } - } - if closeConn && conn != nil { - conn.Close() } t.finishTracer(ctx, ri, err, panicErr) t.finishProfiler(ctx) @@ -138,6 +142,12 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { remote.RecycleMessage(sendMsg) // reset rpcinfo t.opt.InitOrResetRPCInfoFunc(ri, conn.RemoteAddr()) + if wrapErr != nil { + err = wrapErr + } + if err != nil && !closeConnOutsideIfErr { + err = nil + } }() ctx = t.startTracer(ctx, ri) ctx = t.startProfiler(ctx) @@ -145,20 +155,18 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { recvMsg.SetPayloadCodec(t.opt.PayloadCodec) ctx, err = t.Read(ctx, conn, recvMsg) if err != nil { - closeConn = true t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, true) t.OnError(ctx, err, conn) - return nil + return err } var methodInfo serviceinfo.MethodInfo if methodInfo, err = GetMethodInfo(ri, t.svcInfo); err != nil { // it won't be err, because the method has been checked in decode, err check here just do defensive inspection - closeConn = true t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, true) // for proxy case, need read actual remoteAddr, error print must exec after writeErrorReplyIfNeeded t.OnError(ctx, err, conn) - return nil + return err } if methodInfo.OneWay() { sendMsg = remote.NewMessage(nil, t.svcInfo, ri, remote.Reply, remote.Server) @@ -171,17 +179,20 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { // error cannot be wrapped to print here, so it must exec before NewTransError t.OnError(ctx, err, conn) err = remote.NewTransError(remote.InternalError, err) - closeConn = t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, false) - return nil + if closeConn := t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, false); closeConn { + return err + } + // connection don't need to be closed when the error is return by the server handler + closeConnOutsideIfErr = false + return } remote.FillSendMsgFromRecvMsg(recvMsg, sendMsg) if ctx, err = t.transPipe.Write(ctx, conn, sendMsg); err != nil { - closeConn = true t.OnError(ctx, err, conn) - return nil + return err } - return nil + return } // OnMessage implements the remote.ServerTransHandler interface. diff --git a/pkg/remote/trans/default_server_handler_test.go b/pkg/remote/trans/default_server_handler_test.go index 37a3ec980c..2b095925c2 100644 --- a/pkg/remote/trans/default_server_handler_test.go +++ b/pkg/remote/trans/default_server_handler_test.go @@ -18,11 +18,17 @@ package trans import ( "context" + "errors" "net" "testing" + "github.com/golang/mock/gomock" + "github.com/cloudwego/kitex/internal/mocks" + "github.com/cloudwego/kitex/internal/mocks/stats" + internal_stats "github.com/cloudwego/kitex/internal/stats" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" ) @@ -77,3 +83,110 @@ func TestDefaultSvrTransHandler(t *testing.T) { test.Assert(t, tagEncode == 1, tagEncode) test.Assert(t, tagDecode == 1, tagDecode) } + +func TestSvrTransHandlerBizError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockTracer := stats.NewMockTracer(ctrl) + mockTracer.EXPECT().Start(gomock.Any()).DoAndReturn(func(ctx context.Context) context.Context { return ctx }).AnyTimes() + mockTracer.EXPECT().Finish(gomock.Any()).DoAndReturn(func(ctx context.Context) { + err := rpcinfo.GetRPCInfo(ctx).Stats().Error() + test.Assert(t, err != nil) + }).AnyTimes() + + buf := remote.NewReaderWriterBuffer(1024) + ext := &MockExtension{ + NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { + return buf + }, + NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { + return buf + }, + } + + tracerCtl := &internal_stats.Controller{} + tracerCtl.Append(mockTracer) + opt := &remote.ServerOption{ + Codec: &MockCodec{ + EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { + return nil + }, + DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + return nil + }, + }, + SvcInfo: mocks.ServiceInfo(), + TracerCtl: tracerCtl, + InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { + rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr) + return ri + }, + } + ri := rpcinfo.NewRPCInfo(rpcinfo.EmptyEndpointInfo(), rpcinfo.FromBasicInfo(&rpcinfo.EndpointBasicInfo{}), + rpcinfo.NewInvocation("", mocks.MockMethod), nil, rpcinfo.NewRPCStats()) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + svrHandler, err := NewDefaultSvrTransHandler(opt, ext) + pl := remote.NewTransPipeline(svrHandler) + svrHandler.SetPipeline(pl) + if setter, ok := svrHandler.(remote.InvokeHandleFuncSetter); ok { + setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + return kerrors.ErrBiz.WithCause(errors.New("mock")) + }) + } + test.Assert(t, err == nil) + err = svrHandler.OnRead(ctx, &mocks.Conn{}) + test.Assert(t, err == nil) +} + +func TestSvrTransHandlerReadErr(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockTracer := stats.NewMockTracer(ctrl) + mockTracer.EXPECT().Start(gomock.Any()).DoAndReturn(func(ctx context.Context) context.Context { return ctx }).AnyTimes() + mockTracer.EXPECT().Finish(gomock.Any()).DoAndReturn(func(ctx context.Context) { + err := rpcinfo.GetRPCInfo(ctx).Stats().Error() + test.Assert(t, err != nil) + }).AnyTimes() + + buf := remote.NewReaderWriterBuffer(1024) + ext := &MockExtension{ + NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { + return buf + }, + NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { + return buf + }, + } + + mockErr := errors.New("mock") + tracerCtl := &internal_stats.Controller{} + tracerCtl.Append(mockTracer) + opt := &remote.ServerOption{ + Codec: &MockCodec{ + EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { + return nil + }, + DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { + return mockErr + }, + }, + SvcInfo: mocks.ServiceInfo(), + TracerCtl: tracerCtl, + InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { + rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr) + return ri + }, + } + ri := rpcinfo.NewRPCInfo(rpcinfo.EmptyEndpointInfo(), rpcinfo.FromBasicInfo(&rpcinfo.EndpointBasicInfo{}), + rpcinfo.NewInvocation("", mocks.MockMethod), nil, rpcinfo.NewRPCStats()) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + + svrHandler, err := NewDefaultSvrTransHandler(opt, ext) + test.Assert(t, err == nil) + err = svrHandler.OnRead(ctx, &mocks.Conn{}) + test.Assert(t, err != nil) + test.Assert(t, errors.Is(err, mockErr)) +} diff --git a/pkg/remote/trans/detection/server_handler.go b/pkg/remote/trans/detection/server_handler.go index d1e0856171..6fb09dfef4 100644 --- a/pkg/remote/trans/detection/server_handler.go +++ b/pkg/remote/trans/detection/server_handler.go @@ -28,22 +28,17 @@ import ( "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" - transNetpoll "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" - "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" ) // NewSvrTransHandlerFactory detection factory construction -func NewSvrTransHandlerFactory() remote.ServerTransHandlerFactory { - return &svrTransHandlerFactory{ - http2: nphttp2.NewSvrTransHandlerFactory(), - netpoll: transNetpoll.NewSvrTransHandlerFactory(), - } +func NewSvrTransHandlerFactory(nonHttp2, http2 remote.ServerTransHandlerFactory) remote.ServerTransHandlerFactory { + return &svrTransHandlerFactory{nonHttp2, http2} } type svrTransHandlerFactory struct { - http2 remote.ServerTransHandlerFactory - netpoll remote.ServerTransHandlerFactory + defaultHandlerFactory remote.ServerTransHandlerFactory + http2HandlerFactory remote.ServerTransHandlerFactory } func (f *svrTransHandlerFactory) MuxEnabled() bool { @@ -53,18 +48,18 @@ func (f *svrTransHandlerFactory) MuxEnabled() bool { func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { t := &svrTransHandler{} var err error - if t.http2, err = f.http2.NewTransHandler(opt); err != nil { + if t.http2Handler, err = f.http2HandlerFactory.NewTransHandler(opt); err != nil { return nil, err } - if t.netpoll, err = f.netpoll.NewTransHandler(opt); err != nil { + if t.defaultHandler, err = f.defaultHandlerFactory.NewTransHandler(opt); err != nil { return nil, err } return t, nil } type svrTransHandler struct { - http2 remote.ServerTransHandler - netpoll remote.ServerTransHandler + defaultHandler remote.ServerTransHandler + http2Handler remote.ServerTransHandler } func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, send remote.Message) (nctx context.Context, err error) { @@ -91,16 +86,18 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { return r.handler.OnRead(r.ctx, conn) } // Check the validity of client preface. - zr := conn.(netpoll.Connection).Reader() - // read at most avoid block - preface, err := zr.Peek(prefaceReadAtMost) - if err != nil { + var ( + preface []byte + err error + ) + npReader := conn.(interface{ Reader() netpoll.Reader }).Reader() + if preface, err = npReader.Peek(prefaceReadAtMost); err != nil { return err } // compare preface one by one - which := t.netpoll + which := t.defaultHandler if bytes.Equal(preface[:prefaceReadAtMost], grpc.ClientPreface[:prefaceReadAtMost]) { - which = t.http2 + which = t.http2Handler ctx, err = which.OnActive(ctx, conn) if err != nil { return err @@ -111,7 +108,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { } func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) { - // t.http2 should use the ctx returned by OnActive in r.ctx + // t.http2HandlerFactory should use the ctx returned by OnActive in r.ctx if r, ok := ctx.Value(handlerKey{}).(*handlerWrapper); ok && r.ctx != nil { ctx = r.ctx } @@ -135,28 +132,28 @@ func (t *svrTransHandler) which(ctx context.Context) remote.ServerTransHandler { } func (t *svrTransHandler) SetPipeline(pipeline *remote.TransPipeline) { - t.http2.SetPipeline(pipeline) - t.netpoll.SetPipeline(pipeline) + t.http2Handler.SetPipeline(pipeline) + t.defaultHandler.SetPipeline(pipeline) } func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) { - if t, ok := t.http2.(remote.InvokeHandleFuncSetter); ok { + if t, ok := t.http2Handler.(remote.InvokeHandleFuncSetter); ok { t.SetInvokeHandleFunc(inkHdlFunc) } - if t, ok := t.netpoll.(remote.InvokeHandleFuncSetter); ok { + if t, ok := t.defaultHandler.(remote.InvokeHandleFuncSetter); ok { t.SetInvokeHandleFunc(inkHdlFunc) } } func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { - ctx, err := t.netpoll.OnActive(ctx, conn) + ctx, err := t.defaultHandler.OnActive(ctx, conn) if err != nil { return nil, err } - // svrTransHandler wraps two kinds of ServerTransHandler: http2, none-http2. + // svrTransHandler wraps two kinds of ServerTransHandler: http2HandlerFactory, none-http2HandlerFactory. // We think that one connection only use one type, it doesn't need to do protocol detection for every request. // And ctx is initialized with a new connection, so we put a handlerWrapper into ctx, which for recording - // the actual handler, then the later request don't need to do http2 detection. + // the actual handler, then the later request don't need to do http2HandlerFactory detection. return context.WithValue(ctx, handlerKey{}, &handlerWrapper{}), nil } diff --git a/pkg/remote/trans/detection/server_handler_test.go b/pkg/remote/trans/detection/server_handler_test.go index 3ef7d3283f..bf12778c29 100644 --- a/pkg/remote/trans/detection/server_handler_test.go +++ b/pkg/remote/trans/detection/server_handler_test.go @@ -24,6 +24,8 @@ import ( "testing" mocksklog "github.com/cloudwego/kitex/internal/mocks/klog" + "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/golang/mock/gomock" @@ -38,7 +40,7 @@ import ( ) func TestServerHandlerCall(t *testing.T) { - transHdler, _ := NewSvrTransHandlerFactory().NewTransHandler(&remote.ServerOption{ + transHdler, _ := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ SvcInfo: mocks.ServiceInfo(), }) @@ -80,7 +82,7 @@ func TestServerHandlerCall(t *testing.T) { npConn.EXPECT().Reader().Return(npReader).AnyTimes() npConn.EXPECT().RemoteAddr().Return(nil).AnyTimes() - transHdler.(*svrTransHandler).netpoll = hdl + transHdler.(*svrTransHandler).defaultHandler = hdl // case1 successful call: onActive() and onRead() all success triggerActiveErr = false @@ -107,7 +109,7 @@ func TestOnError(t *testing.T) { klog.SetLogger(klog.DefaultLogger()) ctrl.Finish() }() - transHdler, err := NewSvrTransHandlerFactory().NewTransHandler(&remote.ServerOption{ + transHdler, err := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ SvcInfo: mocks.ServiceInfo(), }) test.Assert(t, err == nil) @@ -136,7 +138,7 @@ func TestOnError(t *testing.T) { // TestOnInactive covers onInactive() codes to check panic func TestOnInactive(t *testing.T) { - transHdler, err := NewSvrTransHandlerFactory().NewTransHandler(&remote.ServerOption{ + transHdler, err := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ SvcInfo: mocks.ServiceInfo(), }) test.Assert(t, err == nil) diff --git a/pkg/remote/trans/gonet/bytebuffer.go b/pkg/remote/trans/gonet/bytebuffer.go index c50b9f3209..dfae2731c2 100644 --- a/pkg/remote/trans/gonet/bytebuffer.go +++ b/pkg/remote/trans/gonet/bytebuffer.go @@ -52,7 +52,11 @@ func newBufferReadWriter() interface{} { // NewBufferReader creates a new remote.ByteBuffer using the given netpoll.ZeroCopyReader. func NewBufferReader(ir io.Reader) remote.ByteBuffer { rw := rwPool.Get().(*bufferReadWriter) - rw.reader = netpoll.NewReader(ir) + if npReader, ok := ir.(interface{ Reader() netpoll.Reader }); ok { + rw.reader = npReader.Reader() + } else { + rw.reader = netpoll.NewReader(ir) + } rw.ioReader = ir rw.status = remote.BitReadable rw.readSize = 0 diff --git a/pkg/remote/trans/gonet/server_handler_test.go b/pkg/remote/trans/gonet/server_handler_test.go index 2a7a6e3917..3a7d7e5333 100644 --- a/pkg/remote/trans/gonet/server_handler_test.go +++ b/pkg/remote/trans/gonet/server_handler_test.go @@ -36,11 +36,16 @@ func TestOnActive(t *testing.T) { }, }, } + pl := remote.NewTransPipeline(svrTransHdlr) + svrTransHdlr.SetPipeline(pl) + if setter, ok := svrTransHdlr.(remote.InvokeHandleFuncSetter); ok { + setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + return nil + }) + } ctx := context.Background() - ctx, err := svrTransHdlr.OnActive(ctx, conn) - test.Assert(t, err == nil, err) - err = svrTransHdlr.OnRead(ctx, conn) + _, err := svrTransHdlr.OnActive(ctx, conn) test.Assert(t, err == nil, err) } @@ -141,9 +146,9 @@ func TestPanicAfterRead(t *testing.T) { test.Assert(t, err == nil, err) err = svrTransHdlr.OnRead(ctx, conn) - test.Assert(t, err == nil, err) + test.Assert(t, err != nil, err) test.Assert(t, !isInvoked) - test.Assert(t, isClosed) + test.Assert(t, !isClosed) } // TestNoMethodInfo test server_handler without method info success @@ -173,6 +178,6 @@ func TestNoMethodInfo(t *testing.T) { test.Assert(t, err == nil, err) err = svrTransHdlr.OnRead(ctx, conn) - test.Assert(t, err == nil, err) - test.Assert(t, isClosed) + test.Assert(t, err != nil, err) + test.Assert(t, !isClosed) } diff --git a/pkg/remote/trans/gonet/trans_server.go b/pkg/remote/trans/gonet/trans_server.go index f39cd79e44..9b314b9a8d 100644 --- a/pkg/remote/trans/gonet/trans_server.go +++ b/pkg/remote/trans/gonet/trans_server.go @@ -26,6 +26,8 @@ import ( "sync" "time" + "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/remote/trans" "github.com/cloudwego/kitex/pkg/klog" @@ -77,23 +79,30 @@ func (ts *transServer) BootstrapServer(ln net.Listener) (err error) { for { conn, err := ts.ln.Accept() if err != nil { - klog.Errorf("bootstrap server accept failed, err=%s", err.Error()) + klog.Errorf("KITEX: BootstrapServer accept failed, err=%s", err.Error()) os.Exit(1) } go func() { - ri := ts.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr()) - ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) + var ( + ctx = context.Background() + err error + ) defer func() { transRecover(ctx, conn, "OnRead") - // recycle rpcinfo - rpcinfo.PutRPCInfo(ri) }() + bc := newBufioConn(conn) + ctx, err = ts.transHdlr.OnActive(ctx, bc) + if err != nil { + klog.CtxErrorf(ctx, "KITEX: OnActive error=%s", err) + return + } for { - ts.refreshDeadline(rpcinfo.GetRPCInfo(ctx), conn) - err := ts.transHdlr.OnRead(ctx, conn) + ts.refreshDeadline(rpcinfo.GetRPCInfo(ctx), bc) + err := ts.transHdlr.OnRead(ctx, bc) if err != nil { - ts.onError(ctx, err, conn) - _ = conn.Close() + klog.CtxErrorf(ctx, "KITEX: OnRead Error: %s\n", err.Error()) + ts.onError(ctx, err, bc) + _ = bc.Close() return } } @@ -129,6 +138,70 @@ func (ts *transServer) onError(ctx context.Context, err error, conn net.Conn) { ts.transHdlr.OnError(ctx, err, conn) } +func (ts *transServer) refreshDeadline(ri rpcinfo.RPCInfo, conn net.Conn) { + readTimeout := ri.Config().ReadWriteTimeout() + _ = conn.SetReadDeadline(time.Now().Add(readTimeout)) +} + +// bufioConn implements the net.Conn interface. +type bufioConn struct { + conn net.Conn + r netpoll.Reader +} + +func newBufioConn(c net.Conn) *bufioConn { + return &bufioConn{ + conn: c, + r: netpoll.NewReader(c), + } +} + +func (bc *bufioConn) RawConn() net.Conn { + return bc.conn +} + +func (bc *bufioConn) Read(b []byte) (int, error) { + buf, err := bc.r.Next(len(b)) + if err != nil { + return 0, err + } + copy(b, buf) + return len(b), nil +} + +func (bc *bufioConn) Write(b []byte) (int, error) { + return bc.conn.Write(b) +} + +func (bc *bufioConn) Close() error { + bc.r.Release() + return bc.conn.Close() +} + +func (bc *bufioConn) LocalAddr() net.Addr { + return bc.conn.LocalAddr() +} + +func (bc *bufioConn) RemoteAddr() net.Addr { + return bc.conn.RemoteAddr() +} + +func (bc *bufioConn) SetDeadline(t time.Time) error { + return bc.conn.SetDeadline(t) +} + +func (bc *bufioConn) SetReadDeadline(t time.Time) error { + return bc.conn.SetReadDeadline(t) +} + +func (bc *bufioConn) SetWriteDeadline(t time.Time) error { + return bc.conn.SetWriteDeadline(t) +} + +func (bc *bufioConn) Reader() netpoll.Reader { + return bc.r +} + func transRecover(ctx context.Context, conn net.Conn, funcName string) { panicErr := recover() if panicErr != nil { @@ -139,8 +212,3 @@ func transRecover(ctx context.Context, conn net.Conn, funcName string) { } } } - -func (ts *transServer) refreshDeadline(ri rpcinfo.RPCInfo, conn net.Conn) { - readTimeout := ri.Config().ReadWriteTimeout() - _ = conn.SetReadDeadline(time.Now().Add(readTimeout)) -} diff --git a/pkg/remote/trans/netpoll/server_handler_test.go b/pkg/remote/trans/netpoll/server_handler_test.go index 02a9bc9808..3c7b232e6c 100644 --- a/pkg/remote/trans/netpoll/server_handler_test.go +++ b/pkg/remote/trans/netpoll/server_handler_test.go @@ -21,7 +21,6 @@ import ( "errors" "net" "testing" - "time" "github.com/cloudwego/netpoll" @@ -33,26 +32,33 @@ import ( // TestOnActive test server_handler OnActive success func TestOnActive(t *testing.T) { // 1. prepare mock data - var readTimeout time.Duration conn := &MockNetpollConn{ - SetReadTimeoutFunc: func(timeout time.Duration) (e error) { - readTimeout = timeout - return nil - }, Conn: mocks.Conn{ RemoteAddrFunc: func() (r net.Addr) { return addr }, }, + ReaderFunc: func() (r netpoll.Reader) { + reader := &MockNetpollReader{ + ReleaseFunc: func() (err error) { + return nil + }, + } + return reader + }, + WriterFunc: func() (r netpoll.Writer) { + writer := &MockNetpollWriter{ + FlushFunc: func() (err error) { + return nil + }, + } + return writer + }, } - - // 2. test ctx := context.Background() - ctx, err := svrTransHdlr.OnActive(ctx, conn) - test.Assert(t, err == nil, err) - err = svrTransHdlr.OnRead(ctx, conn) + remote.NewTransPipeline(svrTransHdlr) + _, err := svrTransHdlr.OnActive(ctx, conn) test.Assert(t, err == nil, err) - test.Assert(t, readTimeout == rwTimeout, readTimeout, rwTimeout) } // TestOnRead test server_handler OnRead success @@ -214,11 +220,11 @@ func TestPanicAfterRead(t *testing.T) { test.Assert(t, err == nil, err) err = svrTransHdlr.OnRead(ctx, conn) - test.Assert(t, err == nil, err) + test.Assert(t, err != nil, err) test.Assert(t, isReaderBufReleased) test.Assert(t, !isWriteBufFlushed) test.Assert(t, !isInvoked) - test.Assert(t, isClosed) + test.Assert(t, !isClosed) } // TestNoMethodInfo test server_handler without method info success @@ -268,8 +274,8 @@ func TestNoMethodInfo(t *testing.T) { test.Assert(t, err == nil, err) err = svrTransHdlr.OnRead(ctx, conn) - test.Assert(t, err == nil, err) + test.Assert(t, err != nil, err) test.Assert(t, isReaderBufReleased) test.Assert(t, isWriteBufFlushed) - test.Assert(t, isClosed) + test.Assert(t, !isClosed) } diff --git a/pkg/remote/trans/netpoll/trans_server.go b/pkg/remote/trans/netpoll/trans_server.go index 1522cd34f8..f6db948226 100644 --- a/pkg/remote/trans/netpoll/trans_server.go +++ b/pkg/remote/trans/netpoll/trans_server.go @@ -115,7 +115,10 @@ func (ts *transServer) Shutdown() (err error) { ts.ln.Close() // 2. signal all active connections to close gracefully - g.GracefulShutdown(ctx) + err = g.GracefulShutdown(ctx) + if err != nil { + klog.Warnf("KITEX: server graceful shutdown error: %v", err) + } } } if ts.evl != nil { @@ -154,7 +157,10 @@ func (ts *transServer) onConnRead(ctx context.Context, conn netpoll.Connection) err := ts.transHdlr.OnRead(ctx, conn) if err != nil { ts.onError(ctx, err, conn) - conn.Close() + if conn != nil { + // close the connection if OnRead return error + conn.Close() + } } return nil } diff --git a/pkg/remote/trans/netpoll/trans_server_test.go b/pkg/remote/trans/netpoll/trans_server_test.go index 50dc943a25..e7e9379f8a 100644 --- a/pkg/remote/trans/netpoll/trans_server_test.go +++ b/pkg/remote/trans/netpoll/trans_server_test.go @@ -187,11 +187,16 @@ func TestConnOnActiveAndOnInactivePanic(t *testing.T) { // TestOnConnRead test trans_server onConnRead success func TestConnOnRead(t *testing.T) { // 1. prepare mock data + var isClosed bool conn := &MockNetpollConn{ Conn: mocks.Conn{ RemoteAddrFunc: func() (r net.Addr) { return addr }, + CloseFunc: func() (e error) { + isClosed = true + return nil + }, }, } mockErr := errors.New("mock error") @@ -205,4 +210,5 @@ func TestConnOnRead(t *testing.T) { // 2. test err := transSvr.onConnRead(context.Background(), conn) test.Assert(t, err == nil, err) + test.Assert(t, isClosed) } diff --git a/pkg/remote/trans/netpollmux/mux_conn.go b/pkg/remote/trans/netpollmux/mux_conn.go index 4781781ab0..51f1c60e88 100644 --- a/pkg/remote/trans/netpollmux/mux_conn.go +++ b/pkg/remote/trans/netpollmux/mux_conn.go @@ -115,6 +115,7 @@ func (c *muxCliConn) Close() error { } func (c *muxCliConn) forceClose() error { + c.shardQueue.Close() c.Connection.Close() c.seqIDMap.rangeMap(func(seqID int32, msg EventHandler) { msg.Recv(nil, ErrConnClosed) @@ -169,3 +170,7 @@ type muxConn struct { func (c *muxConn) Put(gt mux.WriterGetter) { c.shardQueue.Add(gt) } + +func (c *muxConn) GracefulShutdown() { + c.shardQueue.Close() +} diff --git a/pkg/remote/trans/netpollmux/server_handler.go b/pkg/remote/trans/netpollmux/server_handler.go index 7a61e5b3d3..6585074bd0 100644 --- a/pkg/remote/trans/netpollmux/server_handler.go +++ b/pkg/remote/trans/netpollmux/server_handler.go @@ -317,26 +317,39 @@ func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error { msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift)) msg.TransInfo().TransStrInfo()[transmeta.HeaderConnectionReadyToReset] = "1" - t.conns.Range(func(k, v interface{}) bool { - wbuf := netpoll.NewLinkBuffer() - bufWriter := np.NewWriterByteBuffer(wbuf) - err := t.codec.Encode(ctx, msg, bufWriter) - bufWriter.Release(err) - if err == nil { - v.(*muxSvrConn).Put(func() (buf netpoll.Writer, isNil bool) { - return wbuf, false - }) - } else { - c := v.(*muxSvrConn) - klog.Warn("KITEX: signal connection closing error:", - err.Error(), c.LocalAddr().String(), "=>", c.RemoteAddr().String()) - } - return true - }) // wait until all notifications are sent and clients stop using those connections done := make(chan struct{}) go func() { + // 1. write control frames to all connections + t.conns.Range(func(k, v interface{}) bool { + sconn := v.(*muxSvrConn) + if !sconn.IsActive() { + return true + } + wbuf := netpoll.NewLinkBuffer() + bufWriter := np.NewWriterByteBuffer(wbuf) + err := t.codec.Encode(ctx, msg, bufWriter) + bufWriter.Release(err) + if err == nil { + sconn.Put(func() (buf netpoll.Writer, isNil bool) { + return wbuf, false + }) + } else { + klog.Warn("KITEX: signal connection closing error:", + err.Error(), sconn.LocalAddr().String(), "=>", sconn.RemoteAddr().String()) + } + return true + }) + // 2. waiting for all tasks finished t.tasks.Wait() + // 3. waiting for all connections have been shutdown gracefully + t.conns.Range(func(k, v interface{}) bool { + sconn := v.(*muxSvrConn) + if sconn.IsActive() { + sconn.GracefulShutdown() + } + return true + }) close(done) }() for { diff --git a/pkg/remote/trans/nphttp2/client_conn_test.go b/pkg/remote/trans/nphttp2/client_conn_test.go index de961d7bbb..b647137202 100644 --- a/pkg/remote/trans/nphttp2/client_conn_test.go +++ b/pkg/remote/trans/nphttp2/client_conn_test.go @@ -59,30 +59,3 @@ func TestClientConn(t *testing.T) { test.Assert(t, err != nil, err) test.Assert(t, n == 0) } - -//func TestClientConnRead(t *testing.T) { -// mockey.PatchConvey("TestClientConnRead", t, func() { -// mockey.Mock((*grpc.Stream).Read).Return(1, io.EOF).Build() -// mockey.Mock((*grpc.Stream).Status).Return(status.New(codes.Internal, "not found")).Build() -// mockey.Mock((*grpc.Stream).BizStatusErr).Return(kerrors.NewGRPCBizStatusError(404, "not found")).Build() -// s := &grpc.Stream{} -// cli := &clientConn{s: s} -// n, err := cli.Read(nil) -// test.Assert(t, n == 1) -// bizErr, _ := kerrors.FromBizStatusError(err) -// test.Assert(t, bizErr.BizStatusCode() == 404) -// test.Assert(t, bizErr.BizMessage() == "not found") -// }) -// mockey.PatchConvey("TestClientConnRead", t, func() { -// mockey.Mock((*grpc.Stream).Read).Return(1, io.EOF).Build() -// mockey.Mock((*grpc.Stream).Status).Return(status.New(codes.Internal, "not found")).Build() -// mockey.Mock((*grpc.Stream).BizStatusErr).Return(nil).Build() -// s := &grpc.Stream{} -// cli := &clientConn{s: s} -// n, err := cli.Read(nil) -// test.Assert(t, err != nil) -// _, isBizErr := kerrors.FromBizStatusError(err) -// test.Assert(t, !isBizErr) -// test.Assert(t, n == 1) -// }) -//} diff --git a/pkg/remote/trans/nphttp2/conn_pool.go b/pkg/remote/trans/nphttp2/conn_pool.go index 323005bf94..9e936c163b 100644 --- a/pkg/remote/trans/nphttp2/conn_pool.go +++ b/pkg/remote/trans/nphttp2/conn_pool.go @@ -18,13 +18,13 @@ package nphttp2 import ( "context" + "crypto/tls" "net" "runtime" "sync" "sync/atomic" "time" - "github.com/cloudwego/netpoll" "golang.org/x/sync/singleflight" "github.com/cloudwego/kitex/pkg/klog" @@ -108,9 +108,12 @@ func (p *connPool) newTransport(ctx context.Context, dialer remote.Dialer, netwo if err != nil { return nil, err } + if opts.TLSConfig != nil { + conn = tls.Client(conn, opts.TLSConfig) + } return grpc.NewClientTransport( ctx, - conn.(netpoll.Connection), + conn, opts, p.remoteService, func(grpc.GoAwayReason) { diff --git a/pkg/remote/trans/nphttp2/grpc/framer.go b/pkg/remote/trans/nphttp2/grpc/framer.go index b90fe290b5..aadf661129 100644 --- a/pkg/remote/trans/nphttp2/grpc/framer.go +++ b/pkg/remote/trans/nphttp2/grpc/framer.go @@ -34,8 +34,8 @@ type framer struct { func newFramer(conn net.Conn, writeBufferSize, readBufferSize, maxHeaderListSize uint32) *framer { var r netpoll.Reader - if npconn, ok := conn.(netpoll.Connection); ok { - r = npconn.Reader() + if npConn, ok := conn.(interface{ Reader() netpoll.Reader }); ok { + r = npConn.Reader() } else { r = netpoll.NewReader(conn) } diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 931ebe510c..acbc86f2a6 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -982,6 +982,7 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.ReadFrame() if err != nil { + klog.Errorf("KITEX: grpc readFrame failed, error=%s", err.Error()) t.Close() // this kicks off resetTransport, so must be last before return return } @@ -1024,6 +1025,7 @@ func (t *http2Client) reader() { continue } else { // Transport error. + klog.Errorf("KITEX: grpc readFrame failed, error=%s", err.Error()) t.Close() return } diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index 0c4feeb8f8..e1f67588cb 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -378,7 +378,7 @@ func (t *http2Server) operateHeaders(frame *grpcframe.MetaHeadersFrame, handle f // HandleStreams receives incoming streams using the given handler. This is // typically run in a separate goroutine. // traceCtx attaches trace to ctx and returns the new context. -func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) { +func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) (err error) { defer close(t.readerDone) for { t.controlBuf.throttle() @@ -404,11 +404,11 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. } if err == io.EOF || err == io.ErrUnexpectedEOF || errors.Is(err, netpoll.ErrEOF) { t.Close() - return + return err } klog.CtxWarnf(t.ctx, "transport: http2Server.HandleStreams failed to read frame: %v", err) t.Close() - return + return err } switch frame := frame.(type) { case *grpcframe.MetaHeadersFrame: diff --git a/pkg/remote/trans/nphttp2/grpc/transport.go b/pkg/remote/trans/nphttp2/grpc/transport.go index b00e9782e2..3f97de6e31 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport.go +++ b/pkg/remote/trans/nphttp2/grpc/transport.go @@ -26,6 +26,7 @@ package grpc import ( "bytes" "context" + "crypto/tls" "errors" "fmt" "io" @@ -578,6 +579,8 @@ type ConnectOptions struct { MaxHeaderListSize *uint32 // ShortConn indicates whether the connection will be reused from grpc conn pool ShortConn bool + // TLSConfig + TLSConfig *tls.Config } // NewServerTransport creates a ServerTransport with conn or non-nil error @@ -685,7 +688,7 @@ type ClientTransport interface { // Write methods for a given Stream will be called serially. type ServerTransport interface { // HandleStreams receives incoming streams using the given handler. - HandleStreams(func(*Stream), func(context.Context, string) context.Context) + HandleStreams(func(*Stream), func(context.Context, string) context.Context) error // WriteHeader sends the header metadata for the given stream. // WriteHeader may not be called on all streams. diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index 1e5bdc30a1..cf2a8720f3 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -24,6 +24,7 @@ import ( "runtime/debug" "strings" "sync" + "time" "github.com/cloudwego/netpoll" @@ -95,7 +96,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans) tr := svrTrans.tr - tr.HandleStreams(func(s *grpcTransport.Stream) { + return tr.HandleStreams(func(s *grpcTransport.Stream) { gofunc.GoFunc(ctx, func() { ri := svrTrans.pool.Get().(rpcinfo.RPCInfo) rCtx := rpcinfo.NewCtxWithRPCInfo(s.Context(), ri) @@ -203,7 +204,6 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { }, func(ctx context.Context, method string) context.Context { return ctx }) - return nil } // msg 是解码后的实例,如 Arg 或 Result, 触发上层处理,用于异步 和 服务端处理 @@ -224,9 +224,13 @@ type SvrTrans struct { func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { // set readTimeout to infinity to avoid streaming break // use keepalive to check the health of connection - conn.(netpoll.Connection).SetReadTimeout(grpcTransport.Infinity) + if npConn, ok := conn.(netpoll.Connection); ok { + npConn.SetReadTimeout(grpcTransport.Infinity) + } else { + conn.SetReadDeadline(time.Now().Add(grpcTransport.Infinity)) + } - tr, err := grpcTransport.NewServerTransport(ctx, conn.(netpoll.Connection), t.opt.GRPCCfg) + tr, err := grpcTransport.NewServerTransport(ctx, conn, t.opt.GRPCCfg) if err != nil { return nil, err } diff --git a/pkg/remote/transmeta/metakey.go b/pkg/remote/transmeta/metakey.go index 26b8efdc2d..24799d0465 100644 --- a/pkg/remote/transmeta/metakey.go +++ b/pkg/remote/transmeta/metakey.go @@ -17,6 +17,8 @@ // Package transmeta . package transmeta +import "github.com/bytedance/gopkg/cloud/metainfo" + // Keys in mesh header. const ( MeshVersion uint16 = iota @@ -61,3 +63,13 @@ const ( HeaderConnectionReadyToReset = "crrst" HeaderProcessAtTime = "K_ProcessAtTime" ) + +// key of acl token +// You can set up acl token through metainfo. +// eg: +// +// ctx = metainfo.WithValue(ctx, "gdpr-token", "your token") +const ( + // GDPRToken is used to set up gdpr token into InfoIDACLToken + GDPRToken = metainfo.PrefixTransient + "gdpr-token" +) diff --git a/pkg/retry/backup_retryer.go b/pkg/retry/backup_retryer.go index 2bb4db89d1..47f4191eac 100644 --- a/pkg/retry/backup_retryer.go +++ b/pkg/retry/backup_retryer.go @@ -51,6 +51,11 @@ type backupRetryer struct { errMsg string } +type resultWrapper struct { + ri rpcinfo.RPCInfo + err error +} + // ShouldRetry implements the Retryer interface. func (r *backupRetryer) ShouldRetry(ctx context.Context, err error, callTimes int, req interface{}, cbKey string) (string, bool) { r.RLock() @@ -89,7 +94,7 @@ func (r *backupRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpc var recordCostDoing int32 = 0 var abort int32 = 0 // notice: buff num of chan is very important here, it cannot less than call times, or the below chan receive will block - done := make(chan error, retryTimes+1) + done := make(chan *resultWrapper, retryTimes+1) cbKey, _ := r.cbContainer.cbCtl.GetKey(ctx, req) timer := time.NewTimer(retryDelay) defer func() { @@ -108,16 +113,19 @@ func (r *backupRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpc if atomic.LoadInt32(&abort) == 1 { return } - var e error + var ( + e error + cRI rpcinfo.RPCInfo + ) defer func() { if panicInfo := recover(); panicInfo != nil { e = panicToErr(ctx, panicInfo, firstRI) } - done <- e + done <- &resultWrapper{cRI, e} }() ct := atomic.AddInt32(&callTimes, 1) callStart := time.Now() - _, _, e = rpcCall(ctx, r) + cRI, _, e = rpcCall(ctx, r) recordCost(ct, callStart, &recordCostDoing, &callCosts, &abort, e) if r.cbContainer.cbStat { circuitbreak.RecordStat(ctx, req, nil, e, cbKey, r.cbContainer.cbCtl, r.cbContainer.cbPanel) @@ -130,15 +138,15 @@ func (r *backupRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rpc doCall = true timer.Reset(retryDelay) } - case e := <-done: - if e != nil && errors.Is(e, kerrors.ErrRPCFinish) { + case res := <-done: + if res.err != nil && errors.Is(res.err, kerrors.ErrRPCFinish) { // To ignore resp concurrent write, the later response won't do decode and return ErrRPCFinish. // But if the cost of decode is long, ErrRPCFinish will return before previous normal call. continue } atomic.StoreInt32(&abort, 1) - recordRetryInfo(firstRI, atomic.LoadInt32(&callTimes), callCosts.String()) - return false, e + recordRetryInfo(firstRI, res.ri, atomic.LoadInt32(&callTimes), callCosts.String()) + return false, res.err } } } diff --git a/pkg/retry/failure_retryer.go b/pkg/retry/failure_retryer.go index 81932d0aef..ac8ea0a77e 100644 --- a/pkg/retry/failure_retryer.go +++ b/pkg/retry/failure_retryer.go @@ -137,10 +137,6 @@ func (r *failureRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rp // user specified resp to do retry continue } - if i > 0 { - // monitor report just use firstRI - rpcinfo.PutRPCInfo(cRI) - } break } else { if i == retryTimes { @@ -152,7 +148,7 @@ func (r *failureRetryer) Do(ctx context.Context, rpcCall RPCCallFunc, firstRI rp } } } - recordRetryInfo(firstRI, callTimes, callCosts.String()) + recordRetryInfo(firstRI, cRI, callTimes, callCosts.String()) if err == nil && callTimes == 1 { return true, nil } diff --git a/pkg/retry/policy.go b/pkg/retry/policy.go index 575fe11f6e..4701790fd0 100644 --- a/pkg/retry/policy.go +++ b/pkg/retry/policy.go @@ -59,6 +59,7 @@ func BuildBackupRequest(p *BackupPolicy) Policy { } // Policy contains all retry policies +// DON'T FORGET to update Equals() and DeepCopy() if you add new fields type Policy struct { Enable bool `json:"enable"` // 0 is failure retry, 1 is backup @@ -68,7 +69,20 @@ type Policy struct { BackupPolicy *BackupPolicy `json:"backup_policy,omitempty"` } +func (p *Policy) DeepCopy() *Policy { + if p == nil { + return nil + } + return &Policy{ + Enable: p.Enable, + Type: p.Type, + FailurePolicy: p.FailurePolicy.DeepCopy(), + BackupPolicy: p.BackupPolicy.DeepCopy(), + } +} + // FailurePolicy for failure retry +// DON'T FORGET to update Equals() and DeepCopy() if you add new fields type FailurePolicy struct { StopPolicy StopPolicy `json:"stop_policy"` BackOffPolicy *BackOffPolicy `json:"backoff_policy,omitempty"` @@ -93,6 +107,7 @@ func (p FailurePolicy) IsErrorRetryNonNil() bool { } // BackupPolicy for backup request +// DON'T FORGET to update Equals() and DeepCopy() if you add new fields type BackupPolicy struct { RetryDelayMS uint32 `json:"retry_delay_ms"` StopPolicy StopPolicy `json:"stop_policy"` @@ -119,6 +134,7 @@ type CBPolicy struct { } // BackOffPolicy is the BackOff policy. +// DON'T FORGET to update Equals() and DeepCopy() if you add new fields type BackOffPolicy struct { BackOffType BackOffType `json:"backoff_type"` CfgItems map[BackOffCfgKey]float64 `json:"cfg_items,omitempty"` @@ -191,6 +207,18 @@ func (p *FailurePolicy) Equals(np *FailurePolicy) bool { return true } +func (p *FailurePolicy) DeepCopy() *FailurePolicy { + if p == nil { + return nil + } + return &FailurePolicy{ + StopPolicy: p.StopPolicy, + BackOffPolicy: p.BackOffPolicy.DeepCopy(), + RetrySameNode: p.RetrySameNode, + ShouldResultRetry: p.ShouldResultRetry, // don't need DeepCopy + } +} + // Equals to check if BackupPolicy is equal func (p *BackupPolicy) Equals(np *BackupPolicy) bool { if p == nil { @@ -212,6 +240,17 @@ func (p *BackupPolicy) Equals(np *BackupPolicy) bool { return true } +func (p *BackupPolicy) DeepCopy() *BackupPolicy { + if p == nil { + return nil + } + return &BackupPolicy{ + RetryDelayMS: p.RetryDelayMS, + StopPolicy: p.StopPolicy, // not a pointer, will copy the value here + RetrySameNode: p.RetrySameNode, + } +} + // Equals to check if BackOffPolicy is equal. func (p *BackOffPolicy) Equals(np *BackOffPolicy) bool { if p == nil { @@ -235,6 +274,27 @@ func (p *BackOffPolicy) Equals(np *BackOffPolicy) bool { return true } +func (p *BackOffPolicy) DeepCopy() *BackOffPolicy { + if p == nil { + return nil + } + return &BackOffPolicy{ + BackOffType: p.BackOffType, + CfgItems: p.copyCfgItems(), + } +} + +func (p *BackOffPolicy) copyCfgItems() map[BackOffCfgKey]float64 { + if p.CfgItems == nil { + return nil + } + cfgItems := make(map[BackOffCfgKey]float64, len(p.CfgItems)) + for k, v := range p.CfgItems { + cfgItems[k] = v + } + return cfgItems +} + func checkCBErrorRate(p *CBPolicy) error { if p.ErrorRate <= 0 || p.ErrorRate > 0.3 { return fmt.Errorf("invalid retry circuit breaker rate, errRate=%0.2f", p.ErrorRate) diff --git a/pkg/retry/policy_test.go b/pkg/retry/policy_test.go index 8e22e3bfe3..e8a6ca9c30 100644 --- a/pkg/retry/policy_test.go +++ b/pkg/retry/policy_test.go @@ -17,6 +17,7 @@ package retry import ( + "reflect" "testing" jsoniter "github.com/json-iterator/go" @@ -410,3 +411,286 @@ func genRPCInfo() rpcinfo.RPCInfo { ri := rpcinfo.NewRPCInfo(to, to, rpcinfo.NewInvocation("", method), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) return ri } + +func genRPCInfoWithRemoteTag(tags map[string]string) rpcinfo.RPCInfo { + to := remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{Method: method, Tags: tags}, method).ImmutableView() + ri := rpcinfo.NewRPCInfo(to, to, rpcinfo.NewInvocation("", method), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) + return ri +} + +func TestBackOffPolicy_copyCfgItems(t *testing.T) { + type fields struct { + BackOffType BackOffType + CfgItems map[BackOffCfgKey]float64 + } + tests := []struct { + name string + fields fields + want map[BackOffCfgKey]float64 + }{ + { + name: "nil_map", + fields: fields{ + BackOffType: NoneBackOffType, + CfgItems: nil, + }, + want: nil, + }, + { + name: "empty_map", + fields: fields{ + BackOffType: NoneBackOffType, + CfgItems: make(map[BackOffCfgKey]float64), + }, + want: make(map[BackOffCfgKey]float64), + }, + { + name: "not_empty_map", + fields: fields{ + BackOffType: NoneBackOffType, + CfgItems: map[BackOffCfgKey]float64{ + MinMSBackOffCfgKey: 1, + MaxMSBackOffCfgKey: 2, + }, + }, + want: map[BackOffCfgKey]float64{ + MinMSBackOffCfgKey: 1, + MaxMSBackOffCfgKey: 2, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &BackOffPolicy{ + BackOffType: tt.fields.BackOffType, + CfgItems: tt.fields.CfgItems, + } + if got := p.copyCfgItems(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("copyCfgItems() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestBackOffPolicy_DeepCopy(t *testing.T) { + type fields struct { + p *BackOffPolicy + } + tests := []struct { + name string + fields fields + want *BackOffPolicy + }{ + { + name: "nil_policy", + fields: fields{ + p: nil, + }, + want: nil, + }, + { + name: "empty_policy", + fields: fields{ + p: &BackOffPolicy{ + BackOffType: NoneBackOffType, + CfgItems: make(map[BackOffCfgKey]float64), + }, + }, + want: &BackOffPolicy{ + BackOffType: NoneBackOffType, + CfgItems: make(map[BackOffCfgKey]float64), + }, + }, + { + name: "not_empty_policy", + fields: fields{ + p: &BackOffPolicy{ + BackOffType: NoneBackOffType, + CfgItems: map[BackOffCfgKey]float64{ + MinMSBackOffCfgKey: 1, + MaxMSBackOffCfgKey: 2, + }, + }, + }, + want: &BackOffPolicy{ + BackOffType: NoneBackOffType, + CfgItems: map[BackOffCfgKey]float64{ + MinMSBackOffCfgKey: 1, + MaxMSBackOffCfgKey: 2, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.fields.p.DeepCopy(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("DeepCopy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestBackupPolicy_DeepCopy(t *testing.T) { + type fields struct { + p *BackupPolicy + } + tests := []struct { + name string + fields fields + want *BackupPolicy + }{ + { + name: "nil_policy", + fields: fields{ + p: nil, + }, + want: nil, + }, + { + name: "empty_policy", + fields: fields{ + p: &BackupPolicy{ + RetryDelayMS: 0, + StopPolicy: StopPolicy{}, + RetrySameNode: false, + }, + }, + want: &BackupPolicy{ + RetryDelayMS: 0, + StopPolicy: StopPolicy{}, + RetrySameNode: false, + }, + }, + { + name: "not_empty_policy", + fields: fields{ + p: &BackupPolicy{ + RetryDelayMS: 1, + StopPolicy: StopPolicy{}, + RetrySameNode: true, + }, + }, + want: &BackupPolicy{ + RetryDelayMS: 1, + StopPolicy: StopPolicy{}, + RetrySameNode: true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := tt.fields.p + if got := p.DeepCopy(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("DeepCopy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFailurePolicy_DeepCopy(t *testing.T) { + type fields struct { + p *FailurePolicy + } + tests := []struct { + name string + fields fields + want *FailurePolicy + }{ + { + name: "nil_policy", + fields: fields{ + p: nil, + }, + want: nil, + }, + { + name: "empty_policy", + fields: fields{ + p: &FailurePolicy{}, + }, + want: &FailurePolicy{}, + }, + { + name: "not_empty_policy", + fields: fields{ + p: &FailurePolicy{ + StopPolicy: StopPolicy{}, + BackOffPolicy: &BackOffPolicy{}, + RetrySameNode: true, + ShouldResultRetry: &ShouldResultRetry{}, + }, + }, + want: &FailurePolicy{ + StopPolicy: StopPolicy{}, + BackOffPolicy: &BackOffPolicy{}, + RetrySameNode: true, + ShouldResultRetry: &ShouldResultRetry{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.fields.p.DeepCopy(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("DeepCopy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPolicy_DeepCopy(t *testing.T) { + type fields struct { + p *Policy + } + tests := []struct { + name string + fields fields + want *Policy + }{ + { + name: "nil_policy", + fields: fields{ + p: nil, + }, + want: nil, + }, + { + name: "empty_policy", + fields: fields{ + p: &Policy{}, + }, + want: &Policy{}, + }, + { + name: "not_empty_policy", + fields: fields{ + p: &Policy{ + Enable: true, + Type: BackupType, + FailurePolicy: &FailurePolicy{ + RetrySameNode: true, + }, + BackupPolicy: &BackupPolicy{ + RetryDelayMS: 1000, + }, + }, + }, + want: &Policy{ + Enable: true, + Type: BackupType, + FailurePolicy: &FailurePolicy{ + RetrySameNode: true, + }, + BackupPolicy: &BackupPolicy{ + RetryDelayMS: 1000, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.fields.p.DeepCopy(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("DeepCopy() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/retry/retryer_test.go b/pkg/retry/retryer_test.go index 5e497a3912..21241c29f1 100644 --- a/pkg/retry/retryer_test.go +++ b/pkg/retry/retryer_test.go @@ -30,6 +30,12 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" ) +var ( + remoteTagKey = "k" + remoteTagValue = "v" + remoteTags = map[string]string{remoteTagKey: remoteTagValue} +) + // test new retry container func TestNewRetryContainer(t *testing.T) { rc := NewRetryContainerWithCB(nil, nil) @@ -416,7 +422,7 @@ func TestSpecifiedErrorRetry(t *testing.T) { if newVal == 1 { return genRPCInfo(), nil, remote.NewTransErrorWithMsg(1000, "mock") } else { - return genRPCInfo(), nil, nil + return genRPCInfoWithRemoteTag(remoteTags), nil, nil } } ctx := context.Background() @@ -437,6 +443,9 @@ func TestSpecifiedErrorRetry(t *testing.T) { ok, err := rc.WithRetryIfNeeded(ctx, Policy{}, retryWithTransError, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) // case2: specified method retry with error, but use backup request config cannot be effective atomic.StoreInt32(&callTimes, 0) @@ -451,6 +460,7 @@ func TestSpecifiedErrorRetry(t *testing.T) { rc = NewRetryContainer() err = rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(10))}, shouldResultRetry) test.Assert(t, err == nil, err) + ri = genRPCInfo() ok, err = rc.WithRetryIfNeeded(ctx, Policy{}, retryWithTransError, ri, nil) test.Assert(t, err != nil, err) test.Assert(t, !ok) @@ -468,17 +478,24 @@ func TestSpecifiedErrorRetry(t *testing.T) { rc = NewRetryContainer() err = rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) test.Assert(t, err == nil, err) + ri = genRPCInfo() ok, err = rc.WithRetryIfNeeded(ctx, Policy{}, retryWithTransError, ri, nil) test.Assert(t, err != nil) test.Assert(t, !ok) + _, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, !ok) // case4: all error retry atomic.StoreInt32(&callTimes, 0) rc = NewRetryContainer() p := BuildFailurePolicy(NewFailurePolicyWithResultRetry(AllErrorRetry())) + ri = genRPCInfo() ok, err = rc.WithRetryIfNeeded(ctx, p, retryWithTransError, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, !ok) + v, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) } // test specified resp to retry @@ -500,7 +517,7 @@ func TestSpecifiedRespRetry(t *testing.T) { return genRPCInfo(), retryResult, nil } else { retryResult.SetResult(noRetryResp) - return genRPCInfo(), retryResult, nil + return genRPCInfoWithRemoteTag(remoteTags), retryResult, nil } } ctx := context.Background() @@ -522,12 +539,17 @@ func TestSpecifiedRespRetry(t *testing.T) { test.Assert(t, err == nil, err) test.Assert(t, retryResult.GetResult() == noRetryResp, retryResult) test.Assert(t, !ok) + v, ok := ri.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) // case2 specified method retry with resp, but use backup request config cannot be effective atomic.StoreInt32(&callTimes, 0) rc = NewRetryContainer() err = rc.Init(map[string]Policy{Wildcard: BuildBackupRequest(NewBackupPolicy(100))}, shouldResultRetry) test.Assert(t, err == nil, err) + ri = genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) ok, err = rc.WithRetryIfNeeded(ctx, Policy{}, retryWithResp, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, retryResult.GetResult() == retryResp, retryResp) @@ -546,10 +568,14 @@ func TestSpecifiedRespRetry(t *testing.T) { rc = NewRetryContainer() err = rc.Init(map[string]Policy{method: BuildFailurePolicy(NewFailurePolicy())}, shouldResultRetry) test.Assert(t, err == nil, err) + ri = genRPCInfo() + ctx = rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) ok, err = rc.WithRetryIfNeeded(ctx, Policy{}, retryWithResp, ri, nil) test.Assert(t, err == nil, err) test.Assert(t, retryResult.GetResult() == retryResp, retryResult) test.Assert(t, ok) + _, ok = ri.To().Tag(remoteTagKey) + test.Assert(t, !ok) } // test different method use different retry policy @@ -641,16 +667,24 @@ func TestBackupPolicyCall(t *testing.T) { test.Assert(t, err == nil, err) callTimes := int32(0) - ri := genRPCInfo() - ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + firstRI := genRPCInfo() + secondRI := genRPCInfoWithRemoteTag(remoteTags) + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, firstRI) ok, err := rc.WithRetryIfNeeded(ctx, Policy{}, func(ctx context.Context, r Retryer) (rpcinfo.RPCInfo, interface{}, error) { atomic.AddInt32(&callTimes, 1) - time.Sleep(time.Millisecond * 50) - return ri, nil, nil - }, ri, nil) + if atomic.LoadInt32(&callTimes) == 1 { + // mock timeout for the first request and get the response of the backup request. + time.Sleep(time.Millisecond * 50) + return firstRI, nil, nil + } + return secondRI, nil, nil + }, firstRI, nil) test.Assert(t, err == nil, err) test.Assert(t, atomic.LoadInt32(&callTimes) == 2) test.Assert(t, !ok) + v, ok := firstRI.To().Tag(remoteTagKey) + test.Assert(t, ok) + test.Assert(t, v == remoteTagValue) } // test policy noRetry call diff --git a/pkg/retry/util.go b/pkg/retry/util.go index a22cb413ac..8428fea60b 100644 --- a/pkg/retry/util.go +++ b/pkg/retry/util.go @@ -154,12 +154,19 @@ func appendErrMsg(err error, msg string) { } } -func recordRetryInfo(ri rpcinfo.RPCInfo, callTimes int32, lastCosts string) { +func recordRetryInfo(firstRI, lastRI rpcinfo.RPCInfo, callTimes int32, lastCosts string) { if callTimes > 1 { - if me := remoteinfo.AsRemoteInfo(ri.To()); me != nil { - me.SetTag(rpcinfo.RetryTag, strconv.Itoa(int(callTimes)-1)) + if firstRe := remoteinfo.AsRemoteInfo(firstRI.To()); firstRe != nil { + // use the remoteInfo of the RPCCall that returns finally, in case the remoteInfo is modified during the call. + if lastRI != nil { + if lastRe := remoteinfo.AsRemoteInfo(lastRI.To()); lastRe != nil { + firstRe.CopyFrom(lastRe) + } + } + + firstRe.SetTag(rpcinfo.RetryTag, strconv.Itoa(int(callTimes)-1)) // record last cost - me.SetTag(rpcinfo.RetryLastCostTag, lastCosts) + firstRe.SetTag(rpcinfo.RetryLastCostTag, lastCosts) } } } diff --git a/pkg/rpcinfo/ctx_test.go b/pkg/rpcinfo/ctx_test.go index aec8299d75..c2f09eb793 100644 --- a/pkg/rpcinfo/ctx_test.go +++ b/pkg/rpcinfo/ctx_test.go @@ -22,6 +22,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/utils" ) func TestNewCtxWithRPCInfo(t *testing.T) { @@ -49,12 +50,28 @@ func TestGetCtxRPCInfo(t *testing.T) { } func TestPutRPCInfo(t *testing.T) { - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, nil) + // pre-declared variables and initialization to ensure readability + method := "TestMethod" + svcName := "TestServiceName" + netAddr := utils.NewNetAddr("TestNetWork", "TestAddress") + tags := map[string]string{"MapTestKey": "MapTestValue"} + + ri := rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo(svcName, method, netAddr, tags), + rpcinfo.NewEndpointInfo(svcName, method, netAddr, tags), + rpcinfo.NewInvocation(svcName, method), + rpcinfo.NewRPCConfig(), + rpcinfo.NewRPCStats(), + ) // nil safe rpcinfo.PutRPCInfo(nil) rpcinfo.PutRPCInfo(ri) - // TODO: mock each field to check if them are recycled respectively - // ... + test.Assert(t, ri.From() == nil) + test.Assert(t, ri.To() == nil) + test.Assert(t, ri.Stats() == nil) + test.Assert(t, ri.Invocation() == nil) + test.Assert(t, ri.Config() == nil) + test.Assert(t, ri.Stats() == nil) } diff --git a/pkg/rpcinfo/remoteinfo/remoteInfo.go b/pkg/rpcinfo/remoteinfo/remoteInfo.go index 17331bf7fb..0e01f94aa6 100644 --- a/pkg/rpcinfo/remoteinfo/remoteInfo.go +++ b/pkg/rpcinfo/remoteinfo/remoteInfo.go @@ -43,6 +43,7 @@ type RemoteInfo interface { // SetRemoteAddr tries to set the network address of the discovery.Instance hold by RemoteInfo. // The result indicates whether the modification is successful. SetRemoteAddr(addr net.Addr) (ok bool) + CopyFrom(from RemoteInfo) ImmutableView() rpcinfo.EndpointInfo } @@ -172,6 +173,22 @@ func (ri *remoteInfo) ForceSetTag(key, value string) { ri.tags[key] = value } +// CopyFrom copies the input RemoteInfo. +// Not deepcopy. +func (ri *remoteInfo) CopyFrom(from RemoteInfo) { + if from == nil { + return + } + ri.Lock() + f := from.(*remoteInfo) + ri.serviceName = f.serviceName + ri.instance = f.instance + ri.tags = f.tags + ri.method = f.method + ri.tagLocks = f.tagLocks + ri.Unlock() +} + // ImmutableView implements rpcinfo.MutableEndpointInfo. func (ri *remoteInfo) ImmutableView() rpcinfo.EndpointInfo { return ri diff --git a/pkg/transmeta/http2.go b/pkg/transmeta/http2.go index 5477dd625e..a7a57b0e9d 100644 --- a/pkg/transmeta/http2.go +++ b/pkg/transmeta/http2.go @@ -25,8 +25,6 @@ import ( "github.com/cloudwego/kitex/pkg/remote/transmeta" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/transport" - - "github.com/bytedance/gopkg/cloud/metainfo" ) // ClientHTTP2Handler default global client metadata handler @@ -44,13 +42,15 @@ func (*clientHTTP2Handler) OnConnectStream(ctx context.Context) (context.Context if !isGRPC(ri) { return ctx, nil } - md := metadata.MD{} + // append more meta into current metadata + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + md = metadata.MD{} + } md.Append(transmeta.HTTPDestService, ri.To().ServiceName()) md.Append(transmeta.HTTPDestMethod, ri.To().Method()) md.Append(transmeta.HTTPSourceService, ri.From().ServiceName()) md.Append(transmeta.HTTPSourceMethod, ri.From().Method()) - // get custom kv from metainfo and set to grpc metadata - metainfo.ToHTTPHeader(ctx, metainfo.HTTPHeader(md)) return metadata.NewOutgoingContext(ctx, md), nil } @@ -84,9 +84,6 @@ func (*serverHTTP2Handler) OnReadStream(ctx context.Context) (context.Context, e if !ok { return ctx, nil } - // get custom kv from metadata and set to context - ctx = metainfo.FromHTTPHeader(ctx, metainfo.HTTPHeader(md)) - ctx = metainfo.TransferForward(ctx) ci := rpcinfo.AsMutableEndpointInfo(ri.From()) if ci != nil { if v := md.Get(transmeta.HTTPSourceService); len(v) != 0 { diff --git a/pkg/transmeta/http2_test.go b/pkg/transmeta/http2_test.go index 562b3b2eb4..6f16590e7c 100644 --- a/pkg/transmeta/http2_test.go +++ b/pkg/transmeta/http2_test.go @@ -38,7 +38,6 @@ func TestIsGRPC(t *testing.T) { args args want bool }{ - // TODO: Add test cases. {"with ttheader", args{ri: rpcinfo.NewRPCInfo( nil, nil, @@ -75,6 +74,83 @@ func TestIsGRPC(t *testing.T) { nil, )}, true}, + {"with http", args{ri: rpcinfo.NewRPCInfo( + nil, + nil, + nil, + func() rpcinfo.RPCConfig { + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.HTTP) + return cfg + }(), + nil, + )}, false}, + + {"with ttheader and http", args{ri: rpcinfo.NewRPCInfo( + nil, + nil, + nil, + func() rpcinfo.RPCConfig { + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeader) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.HTTP) + return cfg + }(), + nil, + )}, false}, + + {"with ttheader framed and http", args{ri: rpcinfo.NewRPCInfo( + nil, + nil, + nil, + func() rpcinfo.RPCConfig { + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeaderFramed) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.HTTP) + return cfg + }(), + nil, + )}, false}, + + {"with ttheader and ttheader framed", args{ri: rpcinfo.NewRPCInfo( + nil, + nil, + nil, + func() rpcinfo.RPCConfig { + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeader) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeaderFramed) + return cfg + }(), + nil, + )}, false}, + + {"with http and grpc", args{ri: rpcinfo.NewRPCInfo( + nil, + nil, + nil, + func() rpcinfo.RPCConfig { + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.HTTP) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC) + return cfg + }(), + nil, + )}, true}, + + {"with ttheader and grpc", args{ri: rpcinfo.NewRPCInfo( + nil, + nil, + nil, + func() rpcinfo.RPCConfig { + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeader) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC) + return cfg + }(), + nil, + )}, true}, + {"with ttheader framed and grpc", args{ri: rpcinfo.NewRPCInfo( nil, nil, @@ -87,6 +163,63 @@ func TestIsGRPC(t *testing.T) { }(), nil, )}, true}, + + {"with ttheader,http and grpc", args{ri: rpcinfo.NewRPCInfo( + nil, + nil, + nil, + func() rpcinfo.RPCConfig { + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeader) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.HTTP) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC) + return cfg + }(), + nil, + )}, true}, + + {"with ttheader framed,http and grpc", args{ri: rpcinfo.NewRPCInfo( + nil, + nil, + nil, + func() rpcinfo.RPCConfig { + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeaderFramed) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.HTTP) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC) + return cfg + }(), + nil, + )}, true}, + + {"with ttheader ,ttheader framed and http", args{ri: rpcinfo.NewRPCInfo( + nil, + nil, + nil, + func() rpcinfo.RPCConfig { + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeader) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeaderFramed) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.HTTP) + return cfg + }(), + nil, + )}, false}, + + {"with ttheader ,ttheader framed,http and grpc", args{ri: rpcinfo.NewRPCInfo( + nil, + nil, + nil, + func() rpcinfo.RPCConfig { + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeader) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeaderFramed) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.HTTP) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC) + return cfg + }(), + nil, + )}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/transmeta/metainfo.go b/pkg/transmeta/metainfo.go index 6c831d83f4..4ce8c8b7de 100644 --- a/pkg/transmeta/metainfo.go +++ b/pkg/transmeta/metainfo.go @@ -22,16 +22,44 @@ import ( "github.com/bytedance/gopkg/cloud/metainfo" "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" + "github.com/cloudwego/kitex/pkg/rpcinfo" ) // singletons . var ( - MetainfoClientHandler remote.MetaHandler = new(metainfoClientHandler) - MetainfoServerHandler remote.MetaHandler = new(metainfoServerHandler) + MetainfoClientHandler = new(metainfoClientHandler) + MetainfoServerHandler = new(metainfoServerHandler) + + _ remote.MetaHandler = MetainfoClientHandler + _ remote.StreamingMetaHandler = MetainfoClientHandler + _ remote.MetaHandler = MetainfoServerHandler + _ remote.StreamingMetaHandler = MetainfoServerHandler ) type metainfoClientHandler struct{} +func (mi *metainfoClientHandler) OnConnectStream(ctx context.Context) (context.Context, error) { + // gRPC send meta when connection is establishing + // so put metainfo into metadata before sending http2 headers + ri := rpcinfo.GetRPCInfo(ctx) + if isGRPC(ri) { + // append kitex metainfo into metadata list + // kitex metainfo key starts with " rpc-transit-xxx " + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + md = metadata.MD{} + } + metainfo.ToHTTPHeader(ctx, metainfo.HTTPHeader(md)) + ctx = metadata.NewOutgoingContext(ctx, md) + } + return ctx, nil +} + +func (mi *metainfoClientHandler) OnReadStream(ctx context.Context) (context.Context, error) { + return ctx, nil +} + func (mi *metainfoClientHandler) WriteMeta(ctx context.Context, sendMsg remote.Message) (context.Context, error) { if metainfo.HasMetaInfo(ctx) { kvs := make(map[string]string) @@ -57,7 +85,26 @@ func (mi *metainfoServerHandler) ReadMeta(ctx context.Context, recvMsg remote.Me ctx = metainfo.SetMetaInfoFromMap(ctx, kvs) } ctx = metainfo.WithBackwardValuesToSend(ctx) - ctx = metainfo.TransferForward(ctx) + return ctx, nil +} + +func (mi *metainfoServerHandler) OnConnectStream(ctx context.Context) (context.Context, error) { + return ctx, nil +} + +func (mi *metainfoServerHandler) OnReadStream(ctx context.Context) (context.Context, error) { + // gRPC server receives meta from http2 header when reading stream + // read all metadata and put those with kitex-metainfo kvs into ctx + ri := rpcinfo.GetRPCInfo(ctx) + if isGRPC(ri) { + // Attach kitex metainfo into context + mdata, ok := metadata.FromIncomingContext(ctx) + if ok { + ctx = metainfo.FromHTTPHeader(ctx, metainfo.HTTPHeader(mdata)) + ctx = metainfo.WithBackwardValuesToSend(ctx) + ctx = metainfo.TransferForward(ctx) + } + } return ctx, nil } diff --git a/pkg/transmeta/metainfo_test.go b/pkg/transmeta/metainfo_test.go index 5965508018..f3576425e1 100644 --- a/pkg/transmeta/metainfo_test.go +++ b/pkg/transmeta/metainfo_test.go @@ -103,6 +103,7 @@ func TestServerReadMetainfo(t *testing.T) { msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift)) ctx, err = MetainfoServerHandler.ReadMeta(ctx0, msg) + ctx = metainfo.TransferForward(ctx) tvs = metainfo.GetAllValues(ctx) pvs = metainfo.GetAllPersistentValues(ctx) test.Assert(t, err == nil) @@ -112,7 +113,7 @@ func TestServerReadMetainfo(t *testing.T) { ctx = metainfo.TransferForward(ctx) tvs = metainfo.GetAllValues(ctx) pvs = metainfo.GetAllPersistentValues(ctx) - test.Assert(t, len(tvs) == 0) + test.Assert(t, len(tvs) == 0, len(tvs)) test.Assert(t, len(pvs) == 1 && pvs["pk"] == "pv") } diff --git a/pkg/utils/config.go b/pkg/utils/config.go index 1a3f6e8c97..783f9f89f8 100644 --- a/pkg/utils/config.go +++ b/pkg/utils/config.go @@ -50,7 +50,13 @@ func GetConfFile() string { return path.Join(GetConfDir(), file) } +// GetEnvLogDir is to get log dir from env. +func GetEnvLogDir() string { + return os.Getenv(EnvLogDir) +} + // GetLogDir gets dir of log file. +// Deprecated: it is suggested to use GetEnvLogDir instead of GetLogDir, and GetEnvLogDir won't return default log dir. func GetLogDir() string { if logDir := os.Getenv(EnvLogDir); logDir != "" { return logDir diff --git a/pkg/utils/config_test.go b/pkg/utils/config_test.go index f313c9e9e2..2759798619 100644 --- a/pkg/utils/config_test.go +++ b/pkg/utils/config_test.go @@ -44,12 +44,22 @@ func TestGetConfFile(t *testing.T) { test.Assert(t, confFile == path.Join(GetConfDir(), DefaultConfFile)) } +func TestGetEnvLogDir(t *testing.T) { + os.Setenv(EnvLogDir, "test_log") + logDir := GetEnvLogDir() + test.Assert(t, logDir == "test_log") + + os.Unsetenv(EnvLogDir) + logDir = GetEnvLogDir() + test.Assert(t, logDir == "") +} + func TestGetLogDir(t *testing.T) { os.Setenv(EnvLogDir, "test_log") logDir := GetLogDir() test.Assert(t, logDir == "test_log") - os.Setenv(EnvLogDir, "") + os.Unsetenv(EnvLogDir) logDir = GetLogDir() test.Assert(t, logDir == DefaultLogDir) } diff --git a/pkg/utils/sharedticker.go b/pkg/utils/sharedticker.go new file mode 100644 index 0000000000..c8ed894004 --- /dev/null +++ b/pkg/utils/sharedticker.go @@ -0,0 +1,102 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "sync" + "time" +) + +type TickerTask interface { + Tick() +} + +// NewSharedTicker constructs a SharedTicker with specified interval. +func NewSharedTicker(interval time.Duration) *SharedTicker { + return &SharedTicker{ + Interval: interval, + tasks: map[TickerTask]struct{}{}, + stopChan: make(chan struct{}, 1), + } +} + +type SharedTicker struct { + sync.Mutex + started bool + Interval time.Duration + tasks map[TickerTask]struct{} + stopChan chan struct{} +} + +func (t *SharedTicker) Add(b TickerTask) { + if b == nil { + return + } + t.Lock() + // Add task + t.tasks[b] = struct{}{} + if !t.started { + t.started = true + go t.Tick(t.Interval) + } + t.Unlock() +} + +func (t *SharedTicker) Delete(b TickerTask) { + t.Lock() + // Delete from tasks + delete(t.tasks, b) + // no tasks remaining then stop the Tick + if len(t.tasks) == 0 { + // unblocked when multi Delete call + select { + case t.stopChan <- struct{}{}: + t.started = false + default: + } + } + t.Unlock() +} + +func (t *SharedTicker) Closed() bool { + t.Lock() + defer t.Unlock() + return !t.started +} + +func (t *SharedTicker) Tick(interval time.Duration) { + var wg sync.WaitGroup + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + t.Lock() + for b := range t.tasks { + wg.Add(1) + go func(b TickerTask) { + defer wg.Done() + b.Tick() + }(b) + } + t.Unlock() + wg.Wait() + case <-t.stopChan: + return + } + } +} diff --git a/pkg/utils/sharedticker_test.go b/pkg/utils/sharedticker_test.go new file mode 100644 index 0000000000..8dd636bd2e --- /dev/null +++ b/pkg/utils/sharedticker_test.go @@ -0,0 +1,83 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "testing" + "time" + + "github.com/golang/mock/gomock" + + mockutils "github.com/cloudwego/kitex/internal/mocks/utils" + "github.com/cloudwego/kitex/internal/test" +) + +func TestSharedTickerAdd(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + rt := mockutils.NewMockTickerTask(ctrl) + rt.EXPECT().Tick().AnyTimes() + st := NewSharedTicker(1) + st.Add(nil) + test.Assert(t, len(st.tasks) == 0) + st.Add(rt) + test.Assert(t, len(st.tasks) == 1) + test.Assert(t, !st.Closed()) +} + +func TestSharedTickerDeleteAndClose(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + st := NewSharedTicker(1) + var ( + num = 10 + tasks = make([]TickerTask, num) + ) + for i := 0; i < num; i++ { + rt := mockutils.NewMockTickerTask(ctrl) + rt.EXPECT().Tick().AnyTimes() + tasks[i] = rt + st.Add(rt) + } + test.Assert(t, len(st.tasks) == num) + for i := 0; i < num; i++ { + st.Delete(tasks[i]) + } + test.Assert(t, len(st.tasks) == 0) + test.Assert(t, st.Closed()) +} + +func TestSharedTickerTick(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + duration := 100 * time.Millisecond + st := NewSharedTicker(duration) + var ( + num = 10 + tasks = make([]TickerTask, num) + ) + for i := 0; i < num; i++ { + rt := mockutils.NewMockTickerTask(ctrl) + rt.EXPECT().Tick().MinTimes(1) // all tasks should be refreshed once during the test + tasks[i] = rt + st.Add(rt) + } + time.Sleep(150 * time.Millisecond) +} diff --git a/server/option_test.go b/server/option_test.go index 6c9ca79da5..946e6f6ec1 100644 --- a/server/option_test.go +++ b/server/option_test.go @@ -30,7 +30,9 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" "github.com/cloudwego/kitex/pkg/remote/trans/detection" + "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/remote/trans/netpollmux" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/stats" @@ -286,7 +288,7 @@ func TestMuxTransportOption(t *testing.T) { err = svr1.Run() test.Assert(t, err == nil, err) iSvr1 := svr1.(*server) - test.DeepEqual(t, iSvr1.opt.RemoteOpt.SvrHandlerFactory, detection.NewSvrTransHandlerFactory()) + test.DeepEqual(t, iSvr1.opt.RemoteOpt.SvrHandlerFactory, detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory())) svr2 := NewServer(WithMuxTransport()) time.AfterFunc(100*time.Millisecond, func() { diff --git a/server/server_test.go b/server/server_test.go index d59e1145ae..58388aa4be 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -375,10 +375,10 @@ func TestServerBoundHandler(t *testing.T) { }, wantInbounds: []remote.InboundHandler{ bound.NewServerLimiterHandler(limiter.NewConnectionLimiter(1000), limiter.NewQPSLimiter(interval, 10000), nil, false), - bound.NewTransMetaHandler([]remote.MetaHandler{noopMetahandler{}, transmeta.MetainfoServerHandler}), + bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}), }, wantOutbounds: []remote.OutboundHandler{ - bound.NewTransMetaHandler([]remote.MetaHandler{noopMetahandler{}, transmeta.MetainfoServerHandler}), + bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}), }, }, { @@ -490,10 +490,10 @@ func TestServerBoundHandler(t *testing.T) { WithMetaHandler(noopMetahandler{}), }, wantInbounds: []remote.InboundHandler{ - bound.NewTransMetaHandler([]remote.MetaHandler{noopMetahandler{}, transmeta.MetainfoServerHandler}), + bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}), }, wantOutbounds: []remote.OutboundHandler{ - bound.NewTransMetaHandler([]remote.MetaHandler{noopMetahandler{}, transmeta.MetainfoServerHandler}), + bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}), }, }, } diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index 81f50fc7b3..c332bc9227 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -97,6 +97,7 @@ func (a *Arguments) buildFlags(version string) *flag.FlagSet { f.DurationVar(&a.ThriftPluginTimeLimit, "thrift-plugin-time-limit", generator.DefaultThriftPluginTimeLimit, "Specify thrift plugin execution time limit.") f.Var(&a.ThriftPlugins, "thrift-plugin", "Specify thrift plugin arguments for the thrift compiler.") f.Var(&a.ProtobufOptions, "protobuf", "Specify arguments for the protobuf compiler.") + f.Var(&a.ProtobufPlugins, "protobuf-plugin", "Specify protobuf plugin arguments for the protobuf compiler.(plugin_name:options:out_dir)") f.BoolVar(&a.CombineService, "combine-service", false, "Combine services in root thrift file.") f.BoolVar(&a.CopyIDL, "copy-idl", false, @@ -107,6 +108,10 @@ func (a *Arguments) buildFlags(version string) *flag.FlagSet { "Use frugal to compile arguments and results when new clients and servers.") f.BoolVar(&a.Record, "record", false, "Record Kitex cmd into kitex-all.sh.") + f.StringVar(&a.TemplateDir, "template-dir", "", + "Use custom template to generate codes.") + f.StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") a.RecordCmd = os.Args a.Version = version a.ThriftOptions = append(a.ThriftOptions, @@ -151,6 +156,10 @@ func (a *Arguments) ParseArgs(version string) { a.checkIDL(f.Args()) a.checkServiceName() + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } a.checkPath() } @@ -187,6 +196,10 @@ func (a *Arguments) checkServiceName() { os.Exit(2) } } else { + if a.TemplateDir != "" { + log.Warn("-template-dir and -service cannot be specified at the same time") + os.Exit(2) + } a.GenerateMain = true } } @@ -215,7 +228,7 @@ func (a *Arguments) checkPath() { log.Warn("Get GOPATH/src relpath failed:", err.Error()) os.Exit(1) } - a.PackagePrefix = filepath.Join(a.PackagePrefix, generator.KitexGenPath) + a.PackagePrefix = filepath.Join(a.PackagePrefix, a.GenPath) } else { if a.ModuleName == "" { log.Warn("Outside of $GOPATH. Please specify a module name with the '-module' flag.") @@ -236,13 +249,13 @@ func (a *Arguments) checkPath() { log.Warn("Get package prefix failed:", err.Error()) os.Exit(1) } - a.PackagePrefix = filepath.Join(a.ModuleName, a.PackagePrefix, generator.KitexGenPath) + a.PackagePrefix = filepath.Join(a.ModuleName, a.PackagePrefix, a.GenPath) } else { if err = initGoMod(pathToGo, a.ModuleName); err != nil { log.Warn("Init go mod failed:", err.Error()) os.Exit(1) } - a.PackagePrefix = filepath.Join(a.ModuleName, generator.KitexGenPath) + a.PackagePrefix = filepath.Join(a.ModuleName, a.GenPath) } } @@ -296,7 +309,7 @@ func (a *Arguments) BuildCmd(out io.Writer) *exec.Cmd { cmd.Args = append(cmd.Args, "-r") } cmd.Args = append(cmd.Args, - "-o", generator.KitexGenPath, + "-o", a.GenPath, "-g", gas, "-p", "kitex="+exe+":"+kas, ) @@ -317,7 +330,7 @@ func (a *Arguments) BuildCmd(out io.Writer) *exec.Cmd { for _, inc := range a.Includes { cmd.Args = append(cmd.Args, "-I", inc) } - outPath := filepath.Join(".", generator.KitexGenPath) + outPath := filepath.Join(".", a.GenPath) if a.Use == "" { os.MkdirAll(outPath, 0o755) } else { @@ -331,6 +344,19 @@ func (a *Arguments) BuildCmd(out io.Writer) *exec.Cmd { for _, po := range a.ProtobufOptions { cmd.Args = append(cmd.Args, "--kitex_opt="+po) } + for _, p := range a.ProtobufPlugins { + pluginParams := strings.Split(p, ":") + if len(pluginParams) != 3 { + log.Warnf("Failed to get the correct protoc plugin parameters for %. Please specify the protoc plugin in the form of \"plugin_name:options:out_dir\"", p) + os.Exit(1) + } + // pluginParams[0] -> plugin name, pluginParams[1] -> plugin options, pluginParams[2] -> out_dir + cmd.Args = append(cmd.Args, + fmt.Sprintf("--%s_out=%s", pluginParams[0], pluginParams[2]), + fmt.Sprintf("--%s_opt=%s", pluginParams[0], pluginParams[1]), + ) + } + cmd.Args = append(cmd.Args, a.IDL) } log.Info(strings.ReplaceAll(strings.Join(cmd.Args, " "), kas, fmt.Sprintf("%q", kas))) diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index 5416f737d7..18f858653d 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -18,10 +18,12 @@ import ( "bytes" "flag" "os" + "os/exec" "strings" "github.com/cloudwego/kitex" kargs "github.com/cloudwego/kitex/tool/cmd/kitex/args" + "github.com/cloudwego/kitex/tool/internal_pkg/log" "github.com/cloudwego/kitex/tool/internal_pkg/pluginmode/protoc" "github.com/cloudwego/kitex/tool/internal_pkg/pluginmode/thriftgo" ) @@ -66,9 +68,21 @@ func main() { if args.Use != "" { out := strings.TrimSpace(out.String()) if strings.HasSuffix(out, thriftgo.TheUseOptionMessage) { - os.Exit(0) + goto NormalExit } } os.Exit(1) } +NormalExit: + if args.IDLType == "thrift" { + cmd := "go mod edit -replace github.com/apache/thrift=github.com/apache/thrift@v0.13.0" + argv := strings.Split(cmd, " ") + err := exec.Command(argv[0], argv[1:]...).Run() + + res := "Done" + if err != nil { + res = err.Error() + } + log.Warn("Adding apache/thrift@v0.13.0 to go.mod for generated code ..........", res) + } } diff --git a/tool/internal_pkg/generator/completor.go b/tool/internal_pkg/generator/completor.go index 8c61082ab7..3c696ab8b6 100644 --- a/tool/internal_pkg/generator/completor.go +++ b/tool/internal_pkg/generator/completor.go @@ -189,3 +189,111 @@ func (c *completer) CompleteMethods() (*File, error) { } return &File{Name: c.handlerPath, Content: buf.String()}, nil } + +type commonCompleter struct { + path string + pkg *PackageInfo + update *Update +} + +func (c *commonCompleter) Complete() (*File, error) { + var w bytes.Buffer + // get AST of main package + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, c.path, nil, parser.ParseComments) + if err != nil { + err = fmt.Errorf("go/parser failed to parse the file: %s, err: %v", c.path, err) + log.Warnf("NOTICE: This is not a bug. We cannot update the file %s because your codes failed to compile. Fix the compile errors and try again.\n%s", c.path, err.Error()) + return nil, err + } + + newMethods, err := c.compare(f) + if err != nil { + return nil, err + } + if len(newMethods) == 0 { + return nil, errNoNewMethod + } + err = c.addImport(&w, newMethods, fset, f) + if err != nil { + return nil, fmt.Errorf("add imports failed error: %v", err) + } + err = c.addImplementations(&w, newMethods) + if err != nil { + return nil, fmt.Errorf("add implements failed error: %v", err) + } + return &File{Name: c.path, Content: w.String()}, nil +} + +func (c *commonCompleter) compare(f *ast.File) ([]*MethodInfo, error) { + var newMethods []*MethodInfo + tmp := c.pkg.Methods + methods := c.pkg.AllMethods() + for _, m := range methods { + c.pkg.Methods = []*MethodInfo{m} + keyTask := &Task{ + Text: c.update.Key, + } + key, err := keyTask.RenderString(c.pkg) + if err != nil { + return newMethods, err + } + have := false + for _, d := range f.Decls { + if fd, ok := d.(*ast.FuncDecl); ok { + _, fn := parseFuncDecl(fd) + if fn == key { + have = true + break + } + } + } + if !have { + newMethods = append(newMethods, m) + } + } + + c.pkg.Methods = tmp + return newMethods, nil +} + +// add imports for new methods +func (c *commonCompleter) addImport(w io.Writer, newMethods []*MethodInfo, fset *token.FileSet, handlerAST *ast.File) error { + existImports := make(map[string]bool) + for _, i := range handlerAST.Imports { + existImports[strings.Trim(i.Path.Value, "\"")] = true + } + tmp := c.pkg.Methods + c.pkg.Methods = newMethods + for _, i := range c.update.ImportTpl { + importTask := &Task{ + Text: i, + } + content, err := importTask.RenderString(c.pkg) + if err != nil { + return err + } + if _, ok := existImports[strings.Trim(content, "\"")]; !ok { + astutil.AddImport(fset, handlerAST, strings.Trim(content, "\"")) + } + } + c.pkg.Methods = tmp + printer.Fprint(w, fset, handlerAST) + return nil +} + +func (c *commonCompleter) addImplementations(w io.Writer, newMethods []*MethodInfo) error { + tmp := c.pkg.Methods + c.pkg.Methods = newMethods + // generate implements of new methods + appendTask := &Task{ + Text: c.update.AppendTpl, + } + content, err := appendTask.RenderString(c.pkg) + if err != nil { + return err + } + _, err = w.Write([]byte(content)) + c.pkg.Methods = tmp + return err +} diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go new file mode 100644 index 0000000000..0547b43cc1 --- /dev/null +++ b/tool/internal_pkg/generator/custom_template.go @@ -0,0 +1,207 @@ +// Copyright 2021 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package generator + +import ( + "fmt" + "io/ioutil" + "os" + "path" + "path/filepath" + + "gopkg.in/yaml.v3" + + "github.com/cloudwego/kitex/tool/internal_pkg/log" + "github.com/cloudwego/kitex/tool/internal_pkg/util" +) + +var DefaultDelimiters = [2]string{"{{", "}}"} + +type updateType string + +const ( + skip updateType = "skip" + cover updateType = "cover" + incrementalUpdate updateType = "append" +) + +type Update struct { + // update type: skip / cover / append. Default is skip. + // If `LoopMethod` is true, only Type field is effect and no append behavior. + Type string `yaml:"type,omitempty"` + // Match key in append type. If the rendered key exists in the file, the method will be skipped. + Key string `yaml:"key,omitempty"` + // Append template. Use it to render append content. + AppendTpl string `yaml:"append_tpl,omitempty"` + // Append import template. Use it to render import content to append. + ImportTpl []string `yaml:"import_tpl,omitempty"` +} + +type Template struct { + // The generated path and its filename. For example: biz/test.go + // will generate test.go in biz directory. + Path string `yaml:"path,omitempty"` + // Render template content, currently only supports go template syntax + Body string `yaml:"body,omitempty"` + // define update behavior + UpdateBehavior *Update `yaml:"update_behavior,omitempty"` + // If set this field, kitex will generate file by cycle. For example: + // test_a/test_b/{{ .Name}}_test.go + LoopMethod bool `yaml:"loop_method,omitempty"` +} + +type customGenerator struct { + fs []*File + pkg *PackageInfo + basePath string +} + +func NewCustomGenerator(pkg *PackageInfo, basePath string) *customGenerator { + return &customGenerator{ + pkg: pkg, + basePath: basePath, + } +} + +func (c *customGenerator) loopGenerate(tpl *Template) error { + tmp := c.pkg.Methods + m := c.pkg.AllMethods() + for _, method := range m { + c.pkg.Methods = []*MethodInfo{method} + pathTask := &Task{ + Text: tpl.Path, + } + renderPath, err := pathTask.RenderString(c.pkg) + if err != nil { + return err + } + filePath := filepath.Join(c.basePath, renderPath) + // update + if util.Exists(filePath) && updateType(tpl.UpdateBehavior.Type) == skip { + continue + } + task := &Task{ + Name: path.Base(renderPath), + Path: filePath, + Text: tpl.Body, + } + f, err := task.Render(c.pkg) + if err != nil { + return err + } + c.fs = append(c.fs, f) + } + c.pkg.Methods = tmp + return nil +} + +func (c *customGenerator) commonGenerate(tpl *Template) error { + pathTask := &Task{ + Text: tpl.Path, + } + renderPath, err := pathTask.RenderString(c.pkg) + if err != nil { + return err + } + filePath := filepath.Join(c.basePath, renderPath) + update := util.Exists(filePath) + if update && updateType(tpl.UpdateBehavior.Type) == skip { + log.Infof("skip generate file %s", tpl.Path) + return nil + } + + var f *File + if update && updateType(tpl.UpdateBehavior.Type) == incrementalUpdate { + cc := &commonCompleter{ + path: filePath, + pkg: c.pkg, + update: tpl.UpdateBehavior, + } + f, err = cc.Complete() + if err != nil { + return err + } + } else { + // just create dir + if tpl.Path[len(tpl.Path)-1] == '/' { + os.MkdirAll(filePath, 0o755) + return nil + } + + task := &Task{ + Name: path.Base(tpl.Path), + Path: filePath, + Text: tpl.Body, + } + + f, err = task.Render(c.pkg) + if err != nil { + return err + } + } + + c.fs = append(c.fs, f) + return nil +} + +func (g *generator) GenerateCustomPackage(pkg *PackageInfo) (fs []*File, err error) { + g.updatePackageInfo(pkg) + + t, err := readTemplates(g.TemplateDir) + if err != nil { + return nil, err + } + + cg := NewCustomGenerator(pkg, g.OutputPath) + for _, tpl := range t { + // special handling Methods field + if tpl.LoopMethod { + err = cg.loopGenerate(tpl) + if err != nil { + return cg.fs, err + } + } else { + err = cg.commonGenerate(tpl) + if err != nil { + return cg.fs, err + } + } + } + + return cg.fs, nil +} + +func readTemplates(dir string) ([]*Template, error) { + files, _ := ioutil.ReadDir(dir) + var ts []*Template + for _, f := range files { + if f.Name() != ExtensionFilename { + path := filepath.Join(dir, f.Name()) + tplData, err := ioutil.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read layout config from %s failed, err: %v", path, err.Error()) + } + t := &Template{ + UpdateBehavior: &Update{Type: string(skip)}, + } + if err = yaml.Unmarshal(tplData, t); err != nil { + return nil, fmt.Errorf("unmarshal layout config failed, err: %v", err.Error()) + } + ts = append(ts, t) + } + } + + return ts, nil +} diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 070b8f759e..b16e3097d8 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -36,13 +36,14 @@ const ( BuildFileName = "build.sh" BootstrapFileName = "bootstrap.sh" - ToolVersionFileName = "kitex.yaml" + ToolVersionFileName = "kitex_info.yaml" HandlerFileName = "handler.go" MainFileName = "main.go" ClientFileName = "client.go" ServerFileName = "server.go" InvokerFileName = "invoker.go" ServiceFileName = "*service.go" + ExtensionFilename = "extensions.yaml" DefaultThriftPluginTimeLimit = time.Minute ) @@ -92,6 +93,7 @@ func AddGlobalDependency(ref, path string) bool { type Generator interface { GenerateService(pkg *PackageInfo) ([]*File, error) GenerateMainPackage(pkg *PackageInfo) ([]*File, error) + GenerateCustomPackage(pkg *PackageInfo) ([]*File, error) } // Config . @@ -114,6 +116,7 @@ type Config struct { CombineService bool // combine services to one service CopyIDL bool ThriftPlugins util.StringSlice + ProtobufPlugins util.StringSlice Features []feature FrugalPretouch bool ThriftPluginTimeLimit time.Duration @@ -123,6 +126,10 @@ type Config struct { Record bool RecordCmd []string + + TemplateDir string + + GenPath string } // Pack packs the Config into a slice of "key=val" strings. @@ -232,13 +239,29 @@ func (c *Config) AddFeature(key string) bool { // ApplyExtension applies template extension. func (c *Config) ApplyExtension() error { - if c.ExtensionFile == "" { + templateExtExist := false + path := filepath.Join(c.TemplateDir, ExtensionFilename) + if c.TemplateDir != "" && util.Exists(path) { + templateExtExist = true + } + + if c.ExtensionFile == "" && !templateExtExist { return nil } ext := new(TemplateExtension) - if err := ext.FromJSONFile(c.ExtensionFile); err != nil { - return fmt.Errorf("read template extension %q failed: %s", c.ExtensionFile, err.Error()) + if c.ExtensionFile != "" { + if err := ext.FromYAMLFile(c.ExtensionFile); err != nil { + return fmt.Errorf("read template extension %q failed: %s", c.ExtensionFile, err.Error()) + } + } + + if templateExtExist { + yamlExt := new(TemplateExtension) + if err := yamlExt.FromYAMLFile(path); err != nil { + return fmt.Errorf("read template extension %q failed: %s", path, err.Error()) + } + ext.Merge(yamlExt) } for _, fn := range ext.FeatureNames { @@ -361,10 +384,7 @@ func (g *generator) GenerateMainPackage(pkg *PackageInfo) (fs []*File, err error func (g *generator) GenerateService(pkg *PackageInfo) ([]*File, error) { g.updatePackageInfo(pkg) - output := filepath.Join(g.OutputPath, KitexGenPath) - if pkg.Namespace != "" { - output = filepath.Join(output, strings.ReplaceAll(pkg.Namespace, ".", "/")) - } + output := filepath.Join(g.OutputPath, util.CombineOutputPath(g.GenPath, pkg.Namespace)) svcPkg := strings.ToLower(pkg.ServiceName) output = filepath.Join(output, svcPkg) ext := g.tmplExt @@ -433,6 +453,7 @@ func (g *generator) updatePackageInfo(pkg *PackageInfo) { pkg.Features = g.Features pkg.ExternalKitexGen = g.Use pkg.FrugalPretouch = g.FrugalPretouch + pkg.Module = g.ModuleName if pkg.Dependencies == nil { pkg.Dependencies = make(map[string]string) } diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index ab0a66f539..b25f8cbff7 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -46,11 +46,13 @@ func TestConfig_Pack(t *testing.T) { CombineService bool CopyIDL bool ThriftPlugins util.StringSlice + ProtobufPlugins util.StringSlice Features []feature FrugalPretouch bool Record bool RecordCmd string ThriftPluginTimeLimit time.Duration + TemplateDir string } tests := []struct { name string @@ -61,7 +63,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "ExtensionFile=", "Record=false", "RecordCmd="}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath="}, }, } for _, tt := range tests { @@ -88,6 +90,7 @@ func TestConfig_Pack(t *testing.T) { Features: tt.fields.Features, FrugalPretouch: tt.fields.FrugalPretouch, ThriftPluginTimeLimit: tt.fields.ThriftPluginTimeLimit, + TemplateDir: tt.fields.TemplateDir, } if gotRes := c.Pack(); !reflect.DeepEqual(gotRes, tt.wantRes) { t.Errorf("Config.Pack() = \n%v\nwant\n%v", gotRes, tt.wantRes) @@ -117,6 +120,7 @@ func TestConfig_Unpack(t *testing.T) { CopyIDL bool Features []feature FrugalPretouch bool + TemplateDir string } type args struct { args []string @@ -157,6 +161,7 @@ func TestConfig_Unpack(t *testing.T) { CopyIDL: tt.fields.CopyIDL, Features: tt.fields.Features, FrugalPretouch: tt.fields.FrugalPretouch, + TemplateDir: tt.fields.TemplateDir, } if err := c.Unpack(tt.args.args); (err != nil) != tt.wantErr { t.Errorf("Config.Unpack() error = %v, wantErr %v", err, tt.wantErr) diff --git a/tool/internal_pkg/generator/template.go b/tool/internal_pkg/generator/template.go index 32f54de4c0..db347903d2 100644 --- a/tool/internal_pkg/generator/template.go +++ b/tool/internal_pkg/generator/template.go @@ -16,43 +16,46 @@ package generator import ( "encoding/json" + "fmt" "io/ioutil" + + "gopkg.in/yaml.v3" ) // APIExtension contains segments to extend an API template. type APIExtension struct { // ImportPaths contains a list of import path that the file should add to the import list. // The paths must be registered with the TemplateExtension's Dependencies fields. - ImportPaths []string `json:"import_paths,omitempty"` + ImportPaths []string `json:"import_paths,omitempty" yaml:"import_paths,omitempty"` // Code snippets to be inserted in the NewX function before assembling options. // It must be a template definition with a name identical to the one kitex uses. - ExtendOption string `json:"extend_option,omitempty"` + ExtendOption string `json:"extend_option,omitempty" yaml:"extend_option,omitempty"` // Code snippets to be appended to the file. // It must be a template definition with a name identical to the one kitex uses. - ExtendFile string `json:"extend_file,omitempty"` + ExtendFile string `json:"extend_file,omitempty" yaml:"extend_file,omitempty"` } // TemplateExtension extends templates that generates files in *service packages. type TemplateExtension struct { // FeatureNames registers some names to the scope for the code generating phrase, where templates can use the `HasFeature` function to query. - FeatureNames []string `json:"feature_names,omitempty"` + FeatureNames []string `json:"feature_names,omitempty" yaml:"feature_names,omitempty"` // EnableFeatures marks on which features that `HasFeature` queries should return true. - EnableFeatures []string `json:"enable_features,omitempty"` + EnableFeatures []string `json:"enable_features,omitempty" yaml:"enable_features,omitempty"` // Dependencies is a mapping from import path to package names/aliases. - Dependencies map[string]string `json:"dependencies,omitempty"` + Dependencies map[string]string `json:"dependencies,omitempty" yaml:"dependencies,omitempty"` // Extension for client.go . - ExtendClient *APIExtension `json:"extend_client,omitempty"` + ExtendClient *APIExtension `json:"extend_client,omitempty" yaml:"extend_client,omitempty"` // Extension for server.go . - ExtendServer *APIExtension `json:"extend_server,omitempty"` + ExtendServer *APIExtension `json:"extend_server,omitempty" yaml:"extend_server,omitempty"` // Extension for invoker.go . - ExtendInvoker *APIExtension `json:"extend_invoker,omitempty"` + ExtendInvoker *APIExtension `json:"extend_invoker,omitempty" yaml:"extend_invoker,omitempty"` } // FromJSONFile unmarshals a TemplateExtension with JSON format from the given file. @@ -75,3 +78,72 @@ func (p *TemplateExtension) ToJSONFile(filename string) error { } return ioutil.WriteFile(filename, data, 0o644) } + +// FromYAMLFile unmarshals a TemplateExtension with YAML format from the given file. +func (p *TemplateExtension) FromYAMLFile(filename string) error { + if p == nil { + return nil + } + data, err := ioutil.ReadFile(filename) + if err != nil { + return err + } + return yaml.Unmarshal(data, p) +} + +func (p *TemplateExtension) ToYAMLFile(filename string) error { + data, err := yaml.Marshal(p) + if err != nil { + return err + } + return ioutil.WriteFile(filename, data, 0o644) +} + +func (p *TemplateExtension) Merge(other *TemplateExtension) { + if other == nil { + return + } + p.FeatureNames = append(p.FeatureNames, other.FeatureNames...) + p.EnableFeatures = append(p.EnableFeatures, other.EnableFeatures...) + + if other.Dependencies != nil { + if p.Dependencies == nil { + p.Dependencies = other.Dependencies + } else { + for k, v := range other.Dependencies { + p.Dependencies[k] = v + } + } + } + + if other.ExtendClient != nil { + if p.ExtendClient == nil { + p.ExtendClient = other.ExtendClient + } else { + p.ExtendClient.Merge(other.ExtendClient) + } + } + if other.ExtendServer != nil { + if p.ExtendServer == nil { + p.ExtendServer = other.ExtendServer + } else { + p.ExtendServer.Merge(other.ExtendServer) + } + } + if other.ExtendInvoker != nil { + if p.ExtendInvoker == nil { + p.ExtendInvoker = other.ExtendInvoker + } else { + p.ExtendInvoker.Merge(other.ExtendInvoker) + } + } +} + +func (a *APIExtension) Merge(other *APIExtension) { + if other == nil { + return + } + a.ImportPaths = append(a.ImportPaths, other.ImportPaths...) + a.ExtendOption = fmt.Sprintf("%v\n%v", a.ExtendOption, other.ExtendOption) + a.ExtendFile = fmt.Sprintf("%v\n%v", a.ExtendFile, other.ExtendFile) +} diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index e5afc845fe..f56e6c297a 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -45,6 +45,7 @@ type PackageInfo struct { ExternalKitexGen string Features []feature FrugalPretouch bool + Module string } // AddImport . @@ -100,7 +101,7 @@ type ServiceInfo struct { // AllMethods returns all methods that the service have. func (s *ServiceInfo) AllMethods() (ms []*MethodInfo) { - ms = s.Methods + ms = append(ms, s.Methods...) for base := s.Base; base != nil; base = base.Base { ms = append(base.Methods, ms...) } @@ -135,11 +136,14 @@ type Parameter struct { } var funcs = map[string]interface{}{ - "ToLower": strings.ToLower, - "LowerFirst": util.LowerFirst, - "UpperFirst": util.UpperFirst, - "NotPtr": util.NotPtr, - "HasFeature": HasFeature, + "ToLower": strings.ToLower, + "LowerFirst": util.LowerFirst, + "UpperFirst": util.UpperFirst, + "NotPtr": util.NotPtr, + "ReplaceString": util.ReplaceString, + "SnakeString": util.SnakeString, + "HasFeature": HasFeature, + "FilterImports": FilterImports, } var templateNames = []string{ @@ -257,3 +261,40 @@ func (t *Task) Render(data interface{}) (*File, error) { } return &File{t.Path, buf.String()}, nil } + +func (t *Task) RenderString(data interface{}) (string, error) { + if t.Template == nil { + err := t.Build() + if err != nil { + return "", err + } + } + + var buf bytes.Buffer + err := t.ExecuteTemplate(&buf, t.Name, data) + if err != nil { + return "", err + } + return buf.String(), nil +} + +func FilterImports(Imports map[string]map[string]bool, ms []*MethodInfo) map[string]map[string]bool { + res := map[string]map[string]bool{} + for _, m := range ms { + if m.Resp != nil { + for _, dep := range m.Resp.Deps { + if _, ok := Imports[dep.ImportPath]; ok { + res[dep.ImportPath] = Imports[dep.ImportPath] + } + } + } + for _, arg := range m.Args { + for _, dep := range arg.Deps { + if _, ok := Imports[dep.ImportPath]; ok { + res[dep.ImportPath] = Imports[dep.ImportPath] + } + } + } + } + return res +} diff --git a/tool/internal_pkg/pluginmode/protoc/plugin.go b/tool/internal_pkg/pluginmode/protoc/plugin.go index a1600f01bf..9cc51349bf 100644 --- a/tool/internal_pkg/pluginmode/protoc/plugin.go +++ b/tool/internal_pkg/pluginmode/protoc/plugin.go @@ -225,6 +225,21 @@ func (pp *protocPlugin) process(gen *protogen.Plugin) { } } + if pp.Config.TemplateDir != "" { + if len(pp.Services) == 0 { + gen.Error(errors.New("no service defined")) + return + } + pp.ServiceInfo = pp.Services[len(pp.Services)-1] + fs, err := pp.kg.GenerateCustomPackage(&pp.PackageInfo) + if err != nil { + pp.err = err + } + for _, f := range fs { + gen.NewGeneratedFile(pp.adjustPath(f.Name), "").P(f.Content) + } + } + if pp.err != nil { gen.Error(pp.err) } diff --git a/tool/internal_pkg/pluginmode/thriftgo/convertor.go b/tool/internal_pkg/pluginmode/thriftgo/convertor.go index 5e42fa7448..51a8452354 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/convertor.go +++ b/tool/internal_pkg/pluginmode/thriftgo/convertor.go @@ -24,6 +24,8 @@ import ( "regexp" "strings" + "github.com/cloudwego/kitex/tool/internal_pkg/util" + "github.com/cloudwego/thriftgo/generator/backend" "github.com/cloudwego/thriftgo/generator/golang" "github.com/cloudwego/thriftgo/parser" @@ -298,7 +300,6 @@ func (c *converter) convertTypes(req *plugin.Request) error { c.svc2ast = make(map[*generator.ServiceInfo]*parser.Thrift) for ast := range req.AST.DepthFirstSearch() { ref, pkg, pth := c.Utils.ParseNamespace(ast) - // make the current ast as an include to produce correct type references. fake := c.copyTreeWithRef(ast, ref) fake.Name2Category = nil @@ -316,13 +317,11 @@ func (c *converter) convertTypes(req *plugin.Request) error { return fmt.Errorf("build scope for fake ast '%s': %w", ast.Filename, err) } c.Utils.SetRootScope(scope) - pi := generator.PkgInfo{ PkgName: pkg, PkgRefName: pkg, - ImportPath: filepath.Join(c.Config.PackagePrefix, pth), + ImportPath: util.CombineOutputPath(c.Config.PackagePrefix, pth), } - for _, svc := range scope.Services() { si, err := c.makeService(pi, svc) if err != nil { @@ -352,7 +351,7 @@ func (c *converter) convertTypes(req *plugin.Request) error { } } // combine service - if c.Config.CombineService && len(ast.Services) > 0 { + if ast == req.AST && c.Config.CombineService && len(ast.Services) > 0 { var svcs []*generator.ServiceInfo var methods []*generator.MethodInfo for _, s := range all[ast.Filename] { diff --git a/tool/internal_pkg/pluginmode/thriftgo/patcher.go b/tool/internal_pkg/pluginmode/thriftgo/patcher.go index 1e9c8ea52b..10ef2f8f57 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/patcher.go +++ b/tool/internal_pkg/pluginmode/thriftgo/patcher.go @@ -176,17 +176,15 @@ func (p *patcher) patch(req *plugin.Request) (patches []*plugin.Generated, err e } p.utils.SetRootScope(scope) - namespace := ast.GetNamespaceOrReferenceName("go") - pkgName := p.utils.NamespaceToPackage(namespace) + pkgName := p.utils.RootScope().FilePackage() - path := p.utils.GetFilePath(ast) - full := filepath.Join(req.OutputPath, path) - dir, base := filepath.Split(full) - target := filepath.Join(dir, "k-"+base) + path := p.utils.CombineOutputPath(req.OutputPath, ast) + base := p.utils.GetFilename(ast) + target := filepath.Join(path, "k-"+base) // Define KitexUnusedProtection in k-consts.go . // Add k-consts.go before target to force the k-consts.go generated by consts.thrift to be renamed. - consts := filepath.Join(filepath.Dir(full), "k-consts.go") + consts := filepath.Join(path, "k-consts.go") if protection[consts] == nil { patch := &plugin.Generated{ Content: "package " + pkgName + "\n" + kitexUnusedProtection, @@ -226,7 +224,7 @@ func (p *patcher) patch(req *plugin.Request) (patches []*plugin.Generated, err e if err != nil { return nil, fmt.Errorf("read %q: %w", ast.Filename, err) } - path := filepath.Join(filepath.Dir(full), filepath.Base(ast.Filename)) + path := filepath.Join(path, filepath.Base(ast.Filename)) patches = append(patches, &plugin.Generated{ Content: string(content), Name: &path, diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index 8e79bf3f06..dfe717d724 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -85,6 +85,18 @@ func Run() int { files = append(files, fs...) } + if conv.Config.TemplateDir != "" { + if len(conv.Services) == 0 { + return conv.fail(errors.New("no service defined in the IDL")) + } + conv.Package.ServiceInfo = conv.Services[len(conv.Services)-1] + fs, err := gen.GenerateCustomPackage(&conv.Package) + if err != nil { + return conv.fail(err) + } + files = append(files, fs...) + } + res := &plugin.Response{ Warnings: conv.Warnings, } diff --git a/tool/internal_pkg/tpl/service.go b/tool/internal_pkg/tpl/service.go index 577b3519f5..d809acfd69 100644 --- a/tool/internal_pkg/tpl/service.go +++ b/tool/internal_pkg/tpl/service.go @@ -269,6 +269,10 @@ func (p *{{.ArgStructName}}) IsSetReq() bool { return p.Req != nil } +func (p *{{.ArgStructName}}) GetFirstArgument() interface{} { + return p.Req +} + type {{.ResStructName}} struct { Success {{.Resp.Type}} } @@ -328,6 +332,10 @@ func (p *{{.ResStructName}}) SetSuccess(x interface{}) { func (p *{{.ResStructName}}) IsSetSuccess() bool { return p.Success != nil } + +func (p *{{.ResStructName}}) GetResult() interface{} { + return p.Success +} {{- end}} {{end}} diff --git a/tool/internal_pkg/util/util.go b/tool/internal_pkg/util/util.go index 23ba8f11b5..b9243a3ca3 100644 --- a/tool/internal_pkg/util/util.go +++ b/tool/internal_pkg/util/util.go @@ -57,7 +57,7 @@ func FormatCode(code []byte) ([]byte, error) { func GetGOPATH() string { goPath := os.Getenv("GOPATH") // If there are many path in GOPATH, pick up the first one. - if GoPaths := strings.Split(goPath, ":"); len(GoPaths) > 1 { + if GoPaths := strings.Split(goPath, ":"); len(GoPaths) >= 1 { return GoPaths[0] } // GOPATH not set through environment variables, try to get one by executing "go env GOPATH" @@ -95,6 +95,29 @@ func LowerFirst(s string) string { return string(rs) } +// ReplaceString be used in string substitution. +func ReplaceString(s, old, new string, n int) string { + return strings.Replace(s, old, new, n) +} + +// SnakeString converts the string 's' to a snake string +func SnakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + for _, d := range []byte(s) { + if d >= 'A' && d <= 'Z' { + if j { + data = append(data, '_') + j = false + } + } else if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data)) +} + // UpperFirst converts the first letter to upper case for the given string. func UpperFirst(s string) string { rs := []rune(s) @@ -212,3 +235,22 @@ func RunGitCommand(gitLink string) (string, string, error) { return gitPath, "", nil } + +// CombineOutputPath read the output and path variables and render them into the final path +func CombineOutputPath(outputPath, ns string) string { + if ns != "" { + ns = strings.ReplaceAll(ns, ".", "/") + } + hasVarNs := strings.Contains(outputPath, "{namespace}") + hasVarNsUnderscore := strings.Contains(outputPath, "{namespaceUnderscore}") + if hasVarNs || hasVarNsUnderscore { + if hasVarNs { + outputPath = strings.ReplaceAll(outputPath, "{namespace}", ns) + } else if hasVarNsUnderscore { + outputPath = strings.ReplaceAll(outputPath, "{namespaceUnderscore}", strings.ReplaceAll(ns, "/", "_")) + } + } else { + outputPath = filepath.Join(outputPath, ns) + } + return outputPath +} diff --git a/tool/internal_pkg/util/util_test.go b/tool/internal_pkg/util/util_test.go new file mode 100644 index 0000000000..d107561775 --- /dev/null +++ b/tool/internal_pkg/util/util_test.go @@ -0,0 +1,34 @@ +// Copyright 2022 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestCombineOutputPath(t *testing.T) { + ns := "aaa.bbb.ccc" + path1 := "kitex_path/code" + output1 := CombineOutputPath(path1, ns) + test.Assert(t, output1 == "kitex_path/code/aaa/bbb/ccc") + path2 := "kitex_path/{namespace}/code" + output2 := CombineOutputPath(path2, ns) + test.Assert(t, output2 == "kitex_path/aaa/bbb/ccc/code") + path3 := "kitex_path/{namespaceUnderscore}/code" + output3 := CombineOutputPath(path3, ns) + test.Assert(t, output3 == "kitex_path/aaa_bbb_ccc/code") +} diff --git a/version.go b/version.go index a370167949..0d306b2791 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package kitex // Name and Version info of this framework, used for statistics and debug const ( Name = "Kitex" - Version = "v0.4.4" + Version = "v0.5.0" )