From 9f6d5178131fd30f1a97c221a4b347e71752e272 Mon Sep 17 00:00:00 2001 From: Joway Date: Tue, 6 Aug 2024 17:07:32 +0800 Subject: [PATCH 01/34] feat: stream v2 api chore: rm debug info feat: support meta handler fix: check stream mw is nil chore: fix unit test (#1545) feat: support multi service (#1538) feat: support stream recv send middleware feat: support long conn feat: support metainfo --- client/client.go | 36 +- client/client_streamx.go | 74 +++ client/service_inline.go | 2 +- client/streamclient/client_option.go | 2 +- client/streamxclient/client.go | 22 + client/streamxclient/client_gen.go | 74 +++ client/streamxclient/client_option.go | 51 ++ .../streamxcallopt/call_option.go | 25 + go.mod | 2 +- go.sum | 16 +- internal/client/option.go | 8 +- internal/server/option.go | 12 +- internal/server/register_option.go | 13 +- internal/server/remote_option.go | 3 - pkg/endpoint/endpoint.go | 2 +- pkg/remote/option.go | 2 + pkg/remote/remotesvr/server.go | 7 +- pkg/remote/trans/streamx/server_handler.go | 206 ++++++++ pkg/streamx/client_provider.go | 29 ++ pkg/streamx/client_provider_internal.go | 24 + pkg/streamx/provider/jsonrpc/client_option.go | 9 + .../provider/jsonrpc/client_provier.go | 47 ++ .../provider/jsonrpc/jsonrpc_gen_test.go | 161 +++++++ .../provider/jsonrpc/jsonrpc_impl_test.go | 68 +++ pkg/streamx/provider/jsonrpc/jsonrpc_test.go | 215 +++++++++ pkg/streamx/provider/jsonrpc/metadata.go | 4 + pkg/streamx/provider/jsonrpc/protocol.go | 167 +++++++ pkg/streamx/provider/jsonrpc/server_option.go | 9 + .../provider/jsonrpc/server_provider.go | 61 +++ pkg/streamx/provider/jsonrpc/stream.go | 122 +++++ pkg/streamx/provider/jsonrpc/transport.go | 153 ++++++ .../provider/jsonrpc/transport_test.go | 143 ++++++ .../provider/ttstream/client_option.go | 9 + .../provider/ttstream/client_provier.go | 67 +++ .../provider/ttstream/client_trans_pool.go | 105 ++++ pkg/streamx/provider/ttstream/frame.go | 130 +++++ pkg/streamx/provider/ttstream/frame_test.go | 51 ++ .../provider/ttstream/meta_frame_handler.go | 50 ++ pkg/streamx/provider/ttstream/metadata.go | 30 ++ pkg/streamx/provider/ttstream/mock_test.go | 49 ++ .../provider/ttstream/server_option.go | 9 + .../provider/ttstream/server_provider.go | 81 ++++ pkg/streamx/provider/ttstream/stream.go | 203 ++++++++ .../ttstream/stream_header_trailer.go | 30 ++ pkg/streamx/provider/ttstream/stream_io.go | 48 ++ pkg/streamx/provider/ttstream/transport.go | 261 ++++++++++ .../provider/ttstream/transport_test.go | 186 ++++++++ .../provider/ttstream/ttstream_client_test.go | 322 +++++++++++++ .../provider/ttstream/ttstream_common_test.go | 44 ++ .../ttstream/ttstream_gen_codec_test.go | 451 ++++++++++++++++++ .../ttstream/ttstream_gen_service_test.go | 205 ++++++++ .../provider/ttstream/ttstream_server_test.go | 81 ++++ pkg/streamx/server_provider.go | 51 ++ pkg/streamx/server_provider_internal.go | 25 + pkg/streamx/stream.go | 272 +++++++++++ pkg/streamx/stream_args.go | 104 ++++ pkg/streamx/stream_middleware.go | 48 ++ pkg/utils/contextmap/contextmap.go | 2 +- server/option_advanced_test.go | 9 +- server/option_test.go | 5 - server/server.go | 44 +- server/server_test.go | 4 +- server/service.go | 40 +- server/stream.go | 1 + server/streamxserver/server.go | 16 + server/streamxserver/server_gen.go | 100 ++++ server/streamxserver/server_option.go | 48 ++ transport/keys.go | 1 + 68 files changed, 4892 insertions(+), 59 deletions(-) create mode 100644 client/client_streamx.go create mode 100644 client/streamxclient/client.go create mode 100644 client/streamxclient/client_gen.go create mode 100644 client/streamxclient/client_option.go create mode 100644 client/streamxclient/streamxcallopt/call_option.go create mode 100644 pkg/remote/trans/streamx/server_handler.go create mode 100644 pkg/streamx/client_provider.go create mode 100644 pkg/streamx/client_provider_internal.go create mode 100644 pkg/streamx/provider/jsonrpc/client_option.go create mode 100644 pkg/streamx/provider/jsonrpc/client_provier.go create mode 100644 pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go create mode 100644 pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go create mode 100644 pkg/streamx/provider/jsonrpc/jsonrpc_test.go create mode 100644 pkg/streamx/provider/jsonrpc/metadata.go create mode 100644 pkg/streamx/provider/jsonrpc/protocol.go create mode 100644 pkg/streamx/provider/jsonrpc/server_option.go create mode 100644 pkg/streamx/provider/jsonrpc/server_provider.go create mode 100644 pkg/streamx/provider/jsonrpc/stream.go create mode 100644 pkg/streamx/provider/jsonrpc/transport.go create mode 100644 pkg/streamx/provider/jsonrpc/transport_test.go create mode 100644 pkg/streamx/provider/ttstream/client_option.go create mode 100644 pkg/streamx/provider/ttstream/client_provier.go create mode 100644 pkg/streamx/provider/ttstream/client_trans_pool.go create mode 100644 pkg/streamx/provider/ttstream/frame.go create mode 100644 pkg/streamx/provider/ttstream/frame_test.go create mode 100644 pkg/streamx/provider/ttstream/meta_frame_handler.go create mode 100644 pkg/streamx/provider/ttstream/metadata.go create mode 100644 pkg/streamx/provider/ttstream/mock_test.go create mode 100644 pkg/streamx/provider/ttstream/server_option.go create mode 100644 pkg/streamx/provider/ttstream/server_provider.go create mode 100644 pkg/streamx/provider/ttstream/stream.go create mode 100644 pkg/streamx/provider/ttstream/stream_header_trailer.go create mode 100644 pkg/streamx/provider/ttstream/stream_io.go create mode 100644 pkg/streamx/provider/ttstream/transport.go create mode 100644 pkg/streamx/provider/ttstream/transport_test.go create mode 100644 pkg/streamx/provider/ttstream/ttstream_client_test.go create mode 100644 pkg/streamx/provider/ttstream/ttstream_common_test.go create mode 100644 pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go create mode 100644 pkg/streamx/provider/ttstream/ttstream_gen_service_test.go create mode 100644 pkg/streamx/provider/ttstream/ttstream_server_test.go create mode 100644 pkg/streamx/server_provider.go create mode 100644 pkg/streamx/server_provider_internal.go create mode 100644 pkg/streamx/stream.go create mode 100644 pkg/streamx/stream_args.go create mode 100644 pkg/streamx/stream_middleware.go create mode 100644 server/streamxserver/server.go create mode 100644 server/streamxserver/server_gen.go create mode 100644 server/streamxserver/server_option.go diff --git a/client/client.go b/client/client.go index c68b2d386c..da41570c27 100644 --- a/client/client.go +++ b/client/client.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "github.com/cloudwego/kitex/pkg/streamx" "runtime" "runtime/debug" "strconv" @@ -70,8 +71,15 @@ type kClient struct { mws []endpoint.Middleware eps endpoint.Endpoint sEps endpoint.Endpoint - opt *client.Options - lbf *lbcache.BalancerFactory + + // streamx + sxEps endpoint.Endpoint + sxStreamMW streamx.StreamMiddleware + sxStreamRecvMW streamx.StreamRecvMiddleware + sxStreamSendMW streamx.StreamSendMiddleware + + opt *client.Options + lbf *lbcache.BalancerFactory inited bool closed bool @@ -91,12 +99,17 @@ func (kf *kcFinalizerClient) Call(ctx context.Context, method string, request, r // NewClient creates a kitex.Client with the given ServiceInfo, it is from generated code. func NewClient(svcInfo *serviceinfo.ServiceInfo, opts ...Option) (Client, error) { + nopts := client.NewOptions(opts) + return NewClientWithOptions(svcInfo, nopts) +} + +func NewClientWithOptions(svcInfo *serviceinfo.ServiceInfo, opts *Options) (Client, error) { if svcInfo == nil { return nil, errors.New("NewClient: no service info") } kc := &kcFinalizerClient{kClient: &kClient{}} kc.svcInfo = svcInfo - kc.opt = client.NewOptions(opts) + kc.opt = opts if err := kc.init(); err != nil { _ = kc.Close() return nil, err @@ -428,17 +441,30 @@ func (kc *kClient) richRemoteOption() { } func (kc *kClient) buildInvokeChain() error { + mwchain := endpoint.Chain(kc.mws...) + innerHandlerEp, err := kc.invokeHandleEndpoint() if err != nil { return err } - kc.eps = endpoint.Chain(kc.mws...)(innerHandlerEp) + kc.eps = mwchain(innerHandlerEp) innerStreamingEp, err := kc.invokeStreamingEndpoint() if err != nil { return err } - kc.sEps = endpoint.Chain(kc.mws...)(innerStreamingEp) + kc.sEps = mwchain(innerStreamingEp) + + // streamx NewStream + innerStreamXEp, err := kc.invokeStreamXEndpoint() + if err != nil { + return err + } + kc.sxEps = mwchain(innerStreamXEp) + // streamx stream call + kc.sxStreamMW = streamx.StreamMiddlewareChain(kc.opt.SMWs...) + kc.sxStreamRecvMW = streamx.StreamRecvMiddlewareChain(kc.opt.SRecvMWs...) + kc.sxStreamSendMW = streamx.StreamSendMiddlewareChain(kc.opt.SSendMWs...) return nil } diff --git a/client/client_streamx.go b/client/client_streamx.go new file mode 100644 index 0000000000..0a84ca3fe0 --- /dev/null +++ b/client/client_streamx.go @@ -0,0 +1,74 @@ +package client + +import ( + "context" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/streamx" +) + +type StreamX interface { + NewStream(ctx context.Context, method string, req any, callOptions ...streamxcallopt.CallOption) (streamx.ClientStream, error) + Middlewares() (streamMW streamx.StreamMiddleware, recvMW streamx.StreamRecvMiddleware, sendMW streamx.StreamSendMiddleware) +} + +func (kc *kClient) Middlewares() (streamMW streamx.StreamMiddleware, recvMW streamx.StreamRecvMiddleware, sendMW streamx.StreamSendMiddleware) { + return kc.sxStreamMW, kc.sxStreamRecvMW, kc.sxStreamSendMW +} + +// return a bottom next function +// bottom next function will create a stream and change the streamx.Args +func (kc *kClient) invokeStreamXEndpoint() (endpoint.Endpoint, error) { + // TODO: implement trans handler layer and use trans factory + //transPipl, err := newCliTransHandler(kc.opt.RemoteOpt) + //if err != nil { + // return nil, err + //} + clientProvider, _ := kc.opt.RemoteOpt.Provider.(streamx.ClientProvider) + clientProvider = streamx.NewClientProvider(clientProvider) // wrap client provider + + return func(ctx context.Context, req, resp interface{}) (err error) { + ri := rpcinfo.GetRPCInfo(ctx) + cs, err := clientProvider.NewStream(ctx, ri) + if err != nil { + return err + } + streamArgs := resp.(streamx.StreamArgs) + // 此后的中间件才会有 Stream + streamx.AsMutableStreamArgs(streamArgs).SetStream(cs) + return nil + }, nil +} + +// NewStream create stream for streamx mode +func (kc *kClient) NewStream(ctx context.Context, method string, req any, callOptions ...streamxcallopt.CallOption) (streamx.ClientStream, error) { + if !kc.inited { + panic("client not initialized") + } + if kc.closed { + panic("client is already closed") + } + if ctx == nil { + panic("ctx is nil") + } + var ri rpcinfo.RPCInfo + ctx, ri, _ = kc.initRPCInfo(ctx, method, 0, nil) + + err := rpcinfo.AsMutableRPCConfig(ri.Config()).SetInteractionMode(rpcinfo.Streaming) + if err != nil { + return nil, err + } + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + ctx = kc.opt.TracerCtl.DoStart(ctx, ri) + + streamArgs := streamx.NewStreamArgs(nil) + // put streamArgs into response arg + // it's an ugly trick but if we don't want to refactor too much, + // this is the only way to compatible with current endpoint design + err = kc.sxEps(ctx, req, streamArgs) + if err != nil { + return nil, err + } + return streamArgs.Stream().(streamx.ClientStream), nil +} diff --git a/client/service_inline.go b/client/service_inline.go index c807d749e3..da7c0c3003 100644 --- a/client/service_inline.go +++ b/client/service_inline.go @@ -130,7 +130,7 @@ func (kc *serviceInlineClient) Call(ctx context.Context, method string, request, } kc.opt.TracerCtl.DoFinish(ctx, ri, reportErr) // If the user start a new goroutine and return before endpoint finished, it may cause panic. - // For example,, if the user writes a timeout Middleware and times out, rpcinfo will be recycled, + // For example,, if the user writes a timeout StreamMiddleware and times out, rpcinfo will be recycled, // but in fact, rpcinfo is still being used when it is executed inside // So if endpoint returns err, client won't recycle rpcinfo. if reportErr == nil { diff --git a/client/streamclient/client_option.go b/client/streamclient/client_option.go index 95f161ac49..0cb6ba94a7 100644 --- a/client/streamclient/client_option.go +++ b/client/streamclient/client_option.go @@ -42,7 +42,7 @@ func WithSuite(suite client.Suite) Option { // WithMiddleware adds middleware for client to handle request. // NOTE: for streaming APIs (bidirectional, client, server), req is not valid, resp is *streaming.Result -// If you want to intercept recv/send calls, please use Recv/Send Middleware +// If you want to intercept recv/send calls, please use Recv/Send StreamMiddleware func WithMiddleware(mw endpoint.Middleware) Option { return ConvertOptionFrom(client.WithMiddleware(mw)) } diff --git a/client/streamxclient/client.go b/client/streamxclient/client.go new file mode 100644 index 0000000000..d477751d48 --- /dev/null +++ b/client/streamxclient/client.go @@ -0,0 +1,22 @@ +package streamxclient + +import ( + "github.com/cloudwego/kitex/client" + iclient "github.com/cloudwego/kitex/internal/client" + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +type Client = client.StreamX + +func NewClient(svcInfo *serviceinfo.ServiceInfo, opts ...Option) (Client, error) { + iopts := make([]client.Option, 0, len(opts)+1) + for _, opt := range opts { + iopts = append(iopts, convertClientOption(opt)) + } + nopts := iclient.NewOptions(iopts) + c, err := client.NewClientWithOptions(svcInfo, nopts) + if err != nil { + return nil, err + } + return c.(client.StreamX), nil +} diff --git a/client/streamxclient/client_gen.go b/client/streamxclient/client_gen.go new file mode 100644 index 0000000000..13ed0f13f1 --- /dev/null +++ b/client/streamxclient/client_gen.go @@ -0,0 +1,74 @@ +package streamxclient + +import ( + "context" + + "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" +) + +func InvokeStream[Header, Trailer, Req, Res any]( + ctx context.Context, cli client.StreamX, smode serviceinfo.StreamingMode, method string, + req *Req, res *Res, callOptions ...streamxcallopt.CallOption, +) (stream *streamx.GenericClientStream[Header, Trailer, Req, Res], err error) { + reqArgs, resArgs := streamx.NewStreamReqArgs(nil), streamx.NewStreamResArgs(nil) + streamArgs := streamx.NewStreamArgs(nil) + // important notes: please don't set a typed nil value into interface arg like NewStreamReqArgs({typ: *Res, ptr: nil}) + // otherwise, reqArgs.Req() may not ==nil forever + if req != nil { + reqArgs.SetReq(req) + } + if res != nil { + resArgs.SetRes(res) + } + + cs, err := cli.NewStream(ctx, method, req, callOptions...) + if err != nil { + return nil, err + } + stream = streamx.NewGenericClientStream[Header, Trailer, Req, Res](cs) + streamx.AsMutableStreamArgs(streamArgs).SetStream(stream) + + streamMW, recvMW, sendMW := cli.Middlewares() + stream.SetStreamRecvMiddleware(recvMW) + stream.SetStreamSendMiddleware(sendMW) + + streamInvoke := func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + // assemble streaming args depend on each stream mode + switch smode { + case serviceinfo.StreamingUnary: + if err = stream.SendMsg(ctx, req); err != nil { + return err + } + if err = stream.RecvMsg(ctx, res); err != nil { + return err + } + resArgs.SetRes(res) + if err = stream.CloseSend(ctx); err != nil { + return err + } + case serviceinfo.StreamingClient: + case serviceinfo.StreamingServer: + if err = stream.SendMsg(ctx, req); err != nil { + return err + } + if err = stream.CloseSend(ctx); err != nil { + return err + } + case serviceinfo.StreamingBidirectional: + default: + } + return nil + } + if streamMW != nil { + err = streamMW(streamInvoke)(ctx, streamArgs, reqArgs, resArgs) + } else { + err = streamInvoke(ctx, streamArgs, reqArgs, resArgs) + } + if err != nil { + return nil, err + } + return stream, nil +} diff --git a/client/streamxclient/client_option.go b/client/streamxclient/client_option.go new file mode 100644 index 0000000000..ed9103fff4 --- /dev/null +++ b/client/streamxclient/client_option.go @@ -0,0 +1,51 @@ +package streamxclient + +import ( + "github.com/cloudwego/kitex/client" + internal_client "github.com/cloudwego/kitex/internal/client" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/utils" +) + +type Option internal_client.Option +type Options = internal_client.Options + +func WithHostPorts(hostports ...string) Option { + return convertInternalClientOption(client.WithHostPorts(hostports...)) +} + +func WithDestService(destService string) Option { + return convertInternalClientOption(client.WithDestService(destService)) +} + +func WithProvider(pvd streamx.ClientProvider) Option { + return Option{F: func(o *internal_client.Options, di *utils.Slice) { + o.RemoteOpt.Provider = pvd + }} +} + +func WithStreamMiddleware(smw streamx.StreamMiddleware) Option { + return Option{F: func(o *internal_client.Options, di *utils.Slice) { + o.SMWs = append(o.SMWs, smw) + }} +} + +func WithStreamRecvMiddleware(smw streamx.StreamRecvMiddleware) Option { + return Option{F: func(o *internal_client.Options, di *utils.Slice) { + o.SRecvMWs = append(o.SRecvMWs, smw) + }} +} + +func WithStreamSendMiddleware(smw streamx.StreamSendMiddleware) Option { + return Option{F: func(o *internal_client.Options, di *utils.Slice) { + o.SSendMWs = append(o.SSendMWs, smw) + }} +} + +func convertInternalClientOption(o internal_client.Option) Option { + return Option{F: o.F} +} + +func convertClientOption(o Option) internal_client.Option { + return internal_client.Option{F: o.F} +} diff --git a/client/streamxclient/streamxcallopt/call_option.go b/client/streamxclient/streamxcallopt/call_option.go new file mode 100644 index 0000000000..2ef16ce3e3 --- /dev/null +++ b/client/streamxclient/streamxcallopt/call_option.go @@ -0,0 +1,25 @@ +package streamxcallopt + +import ( + "fmt" + "strings" + "time" +) + +type CallOptions struct { + rpcTimeout time.Duration + ProviderOption any +} + +type CallOption struct { + f func(o *CallOptions, di *strings.Builder) +} + +type WithCallOption func(o *CallOption) + +func WithRPCTimeout(rpcTimeout time.Duration) CallOption { + return CallOption{f: func(o *CallOptions, di *strings.Builder) { + di.WriteString(fmt.Sprintf("WithRPCTimeout(%d)", rpcTimeout)) + o.rpcTimeout = rpcTimeout + }} +} diff --git a/go.mod b/go.mod index e19ce92b23..c996b007a5 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/cloudwego/dynamicgo v0.4.6-0.20241115162834-0e99bc39b128 github.com/cloudwego/fastpb v0.0.5 github.com/cloudwego/frugal v0.2.0 - github.com/cloudwego/gopkg v0.1.2 + github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4 github.com/cloudwego/localsession v0.1.1 github.com/cloudwego/netpoll v0.6.4 github.com/cloudwego/runtimex v0.1.0 diff --git a/go.sum b/go.sum index ebec916074..0e40c50be3 100644 --- a/go.sum +++ b/go.sum @@ -22,14 +22,26 @@ github.com/cloudwego/fastpb v0.0.5 h1:vYnBPsfbAtU5TVz5+f9UTlmSCixG9F9vRwaqE0mZPZ github.com/cloudwego/fastpb v0.0.5/go.mod h1:Bho7aAKBUtT9RPD2cNVkTdx4yQumfSv3If7wYnm1izk= github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= -github.com/cloudwego/gopkg v0.1.2 h1:650t+RiZGht8qX+y0hl49JXJCuO44GhbGZuxDzr2PyI= -github.com/cloudwego/gopkg v0.1.2/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= +github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4 h1:SHw9GUBBcAnLWeK2MtPH7O6YQG9Q2ZZ8koD/4alpLvE= +github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +<<<<<<< HEAD github.com/cloudwego/localsession v0.1.1 h1:tbK7laDVrYfFDXoBXo4uCGMAxU4qmz2dDm8d4BGBnDo= github.com/cloudwego/localsession v0.1.1/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiXcs7Z+KBHP72Wv8= github.com/cloudwego/netpoll v0.6.4 h1:z/dA4sOTUQof6zZIO4QNnLBXsDFFFEos9OOGloR6kno= github.com/cloudwego/netpoll v0.6.4/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= +======= +github.com/cloudwego/localsession v0.0.2 h1:N9/IDtCPj1fCL9bCTP+DbXx3f40YjVYWcwkJG0YhQkY= +github.com/cloudwego/localsession v0.0.2/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiXcs7Z+KBHP72Wv8= +github.com/cloudwego/netpoll v0.6.5-0.20240905095957-e6ec47be2fe0 h1:2aoCxK8fee7LhwWveg3ORVEDBoMtmTY2NuSAtNGpnFI= +github.com/cloudwego/netpoll v0.6.5-0.20240905095957-e6ec47be2fe0/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= +<<<<<<< HEAD +>>>>>>> 9e44721 (feat: support multi service (#1538)) +======= +github.com/cloudwego/netpoll v0.6.5-0.20240911073319-2ec9568b10cf h1:c/K4XrkloCgZp+En3LjbXtqfr0KQwC85utUvdDm76V4= +github.com/cloudwego/netpoll v0.6.5-0.20240911073319-2ec9568b10cf/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= +>>>>>>> 968bdfc (chore: fix unit test (#1545)) github.com/cloudwego/runtimex v0.1.0 h1:HG+WxWoj5/CDChDZ7D99ROwvSMkuNXAqt6hnhTTZDiI= github.com/cloudwego/runtimex v0.1.0/go.mod h1:23vL/HGV0W8nSCHbe084AgEBdDV4rvXenEUMnUNvUd8= github.com/cloudwego/thriftgo v0.3.18 h1:gnr1vz7G3RbwwCK9AMKHZf63VYGa7ene6WbI9VrBJSw= diff --git a/internal/client/option.go b/internal/client/option.go index 5d3b30635a..968b379b82 100644 --- a/internal/client/option.go +++ b/internal/client/option.go @@ -19,6 +19,7 @@ package client import ( "context" + "github.com/cloudwego/kitex/pkg/streamx" "time" "github.com/cloudwego/localsession/backup" @@ -82,8 +83,11 @@ type Options struct { ACLRules []acl.RejectFunc - MWBs []endpoint.MiddlewareBuilder - IMWBs []endpoint.MiddlewareBuilder + MWBs []endpoint.MiddlewareBuilder + IMWBs []endpoint.MiddlewareBuilder + SMWs []streamx.StreamMiddleware + SRecvMWs []streamx.StreamRecvMiddleware + SSendMWs []streamx.StreamSendMiddleware Bus event.Bus Events event.Queue diff --git a/internal/server/option.go b/internal/server/option.go index b540f03ce4..82138d139c 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -23,8 +23,6 @@ import ( "os/signal" "syscall" - "github.com/cloudwego/localsession/backup" - "github.com/cloudwego/kitex/internal/configutil" "github.com/cloudwego/kitex/internal/stream" "github.com/cloudwego/kitex/pkg/acl" @@ -43,8 +41,10 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" + "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/transmeta" "github.com/cloudwego/kitex/pkg/utils" + "github.com/cloudwego/localsession/backup" ) func init() { @@ -80,6 +80,10 @@ type Options struct { Limit Limit MWBs []endpoint.MiddlewareBuilder + // streamx + SMWs []streamx.StreamMiddleware + SRecvMWs []streamx.StreamRecvMiddleware + SSendMWs []streamx.StreamSendMiddleware Bus event.Bus Events event.Queue @@ -177,9 +181,7 @@ func DefaultSupportedTransportsFunc(option remote.ServerOption) []string { if factory, ok := option.SvrHandlerFactory.(trans.MuxEnabledFlag); ok { if factory.MuxEnabled() { return []string{"ttheader_mux"} - } else { - return []string{"ttheader", "framed", "ttheader_framed", "grpc"} } } - return nil + return []string{"ttheader", "framed", "ttheader_framed", "grpc"} } diff --git a/internal/server/register_option.go b/internal/server/register_option.go index 76233eae74..6cf884ac70 100644 --- a/internal/server/register_option.go +++ b/internal/server/register_option.go @@ -16,7 +16,10 @@ package server -import "github.com/cloudwego/kitex/pkg/endpoint" +import ( + "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/streamx" +) // RegisterOption is the only way to config service registration. type RegisterOption struct { @@ -25,8 +28,12 @@ type RegisterOption struct { // RegisterOptions is used to config service registration. type RegisterOptions struct { - IsFallbackService bool - Middlewares []endpoint.Middleware + IsFallbackService bool + Middlewares []endpoint.Middleware + StreamMiddlewares []streamx.StreamMiddleware + StreamRecvMiddlewares []streamx.StreamRecvMiddleware + StreamSendMiddlewares []streamx.StreamSendMiddleware + Provider streamx.ServerProvider } // NewRegisterOptions creates a register options. diff --git a/internal/server/remote_option.go b/internal/server/remote_option.go index a814df152c..1fc22d71de 100644 --- a/internal/server/remote_option.go +++ b/internal/server/remote_option.go @@ -23,16 +23,13 @@ package server import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" - "github.com/cloudwego/kitex/pkg/remote/trans/detection" "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" - "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" ) func newServerRemoteOption() *remote.ServerOption { return &remote.ServerOption{ TransServerFactory: netpoll.NewTransServerFactory(), - SvrHandlerFactory: detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()), Codec: codec.NewDefaultCodec(), Address: defaultAddress, ExitWaitTime: defaultExitWaitTime, diff --git a/pkg/endpoint/endpoint.go b/pkg/endpoint/endpoint.go index 0c0aff1c14..b5912d279c 100644 --- a/pkg/endpoint/endpoint.go +++ b/pkg/endpoint/endpoint.go @@ -22,7 +22,7 @@ import "context" type Endpoint func(ctx context.Context, req, resp interface{}) (err error) // Middleware deal with input Endpoint and output Endpoint. -type Middleware func(Endpoint) Endpoint +type Middleware func(next Endpoint) Endpoint // MiddlewareBuilder builds a middleware with information from a context. type MiddlewareBuilder func(ctx context.Context) Middleware diff --git a/pkg/remote/option.go b/pkg/remote/option.go index 71e81c2500..93a7ec416d 100644 --- a/pkg/remote/option.go +++ b/pkg/remote/option.go @@ -141,4 +141,6 @@ type ClientOption struct { Option EnableConnPoolReporter bool + + Provider interface{} // streamx.ClientProvider } diff --git a/pkg/remote/remotesvr/server.go b/pkg/remote/remotesvr/server.go index 81ee1f9c94..05c566976b 100644 --- a/pkg/remote/remotesvr/server.go +++ b/pkg/remote/remotesvr/server.go @@ -38,8 +38,6 @@ type server struct { opt *remote.ServerOption listener net.Listener transSvr remote.TransServer - - inkHdlFunc endpoint.Endpoint sync.Mutex } @@ -47,9 +45,8 @@ type server struct { func NewServer(opt *remote.ServerOption, inkHdlFunc endpoint.Endpoint, transHdlr remote.ServerTransHandler) (Server, error) { transSvr := opt.TransServerFactory.NewTransServer(opt, transHdlr) s := &server{ - opt: opt, - inkHdlFunc: inkHdlFunc, - transSvr: transSvr, + opt: opt, + transSvr: transSvr, } return s, nil } diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go new file mode 100644 index 0000000000..7e97ed0955 --- /dev/null +++ b/pkg/remote/trans/streamx/server_handler.go @@ -0,0 +1,206 @@ +/* + * Copyright 2021 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package streamx + +import ( + "context" + "errors" + "io" + "log" + "net" + "runtime/debug" + + "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/streamx" +) + +/* 实际上 remote.ServerTransHandler 真正被 trans_server.go 使用的接口只有: +- OnRead +- OnActive +- OnInactive +- OnError +- GracefulShutdown: assert 方式使用 + +其他接口实际上最终是用来去组装了 transpipeline .... +*/ + +type svrTransHandlerFactory struct { + provider streamx.ServerProvider +} + +// NewSvrTransHandlerFactory ... +func NewSvrTransHandlerFactory(provider streamx.ServerProvider) remote.ServerTransHandlerFactory { + sp := streamx.NewServerProvider(provider) // wrapped server provider + return &svrTransHandlerFactory{provider: sp} +} + +func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { + return &svrTransHandler{ + opt: opt, + provider: f.provider, + }, nil +} + +var _ remote.ServerTransHandler = &svrTransHandler{} +var errProtocolNotMatch = errors.New("protocol not match") + +type svrTransHandler struct { + opt *remote.ServerOption + provider streamx.ServerProvider + inkHdlFunc endpoint.Endpoint +} + +func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) { + t.inkHdlFunc = inkHdlFunc +} + +func (t *svrTransHandler) ProtocolMatch(ctx context.Context, conn net.Conn) (err error) { + if t.provider.Available(ctx, conn) { + return nil + } + return errProtocolNotMatch +} + +func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { + var err error + ctx, err = t.provider.OnActive(ctx, conn) + if err != nil { + return nil, err + } + return ctx, nil +} + +func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { + // connection level goroutine + for { + nctx, ss, nerr := t.provider.OnStream(ctx, conn) + if nerr != nil { + if !errors.Is(nerr, io.EOF) { + klog.CtxErrorf(ctx, "KITEX: OnStream failed: err=%v", nerr) + } + return nerr + } + // stream level goroutine + go func() { + nerr = t.OnStream(nctx, conn, ss) + if nerr != nil { + if !errors.Is(nerr, io.EOF) { + klog.CtxErrorf(ctx, "KITEX: stream ReadStream failed: err=%v", nerr) + } + return + } + }() + } +} + +// OnStream +// - create server stream +// - process server stream +// - close server stream +func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss streamx.ServerStream) (err error) { + // inkHdlFunc 包含了所有中间件 + 用户 serviceInfo.methodHandler + // 这里 streamx 依然会复用原本的 server endpoint.Endpoint 中间件,因为他们都不会单独去取 req/res 的值 + // 无法在保留现有 streaming 功能的情况下,彻底弃用 endpoint.Endpoint , 所以这里依然使用 endpoint 接口 + // 但是对用户 API ,做了单独的封装。把这部分脏逻辑仅暴露在框架中。 + sargs := streamx.NewStreamArgs(ss) + ctx = streamx.WithStreamArgsContext(ctx, sargs) + + ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr()) + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + defer func() { + if rpcinfo.PoolEnabled() { + ri = t.opt.InitOrResetRPCInfoFunc(ri, conn.RemoteAddr()) + // TODO: rpcinfo pool + } + }() + + ink := ri.Invocation().(rpcinfo.InvocationSetter) + ink.SetServiceName(ss.Service()) + ink.SetMethodName(ss.Method()) + if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil { + _ = mutableTo.SetMethod(ss.Method()) + } + //_ = rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(transport.JSONRPC) + + ctx = t.startTracer(ctx, ri) + defer func() { + if err != nil { + log.Println("OnStream failed: ", err) + } + panicErr := recover() + if panicErr != nil { + if conn != nil { + klog.CtxErrorf(ctx, "KITEX: streamx panic happened, close conn, remoteAddress=%s, error=%s\nstack=%s", conn.RemoteAddr(), panicErr, string(debug.Stack())) + } else { + klog.CtxErrorf(ctx, "KITEX: streamx panic happened, error=%v\nstack=%s", panicErr, string(debug.Stack())) + } + } + t.finishTracer(ctx, ri, err, panicErr) + }() + + reqArgs := streamx.NewStreamReqArgs(nil) + resArgs := streamx.NewStreamResArgs(nil) + // server handler (which will call streamxserver.InvokeStream inside) + serr := t.inkHdlFunc(ctx, reqArgs, resArgs) + ctx, err = t.provider.OnStreamFinish(ctx, ss) + if err == nil && serr != nil { + err = serr + } + return err +} + +func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, send remote.Message) (nctx context.Context, err error) { + return ctx, nil +} + +func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) { + return ctx, nil +} + +func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) { + _, _ = t.provider.OnInactive(ctx, conn) +} + +func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) { +} + +func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) { + return ctx, nil +} + +func (t *svrTransHandler) SetPipeline(pipeline *remote.TransPipeline) { +} + +func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context { + c := t.opt.TracerCtl.DoStart(ctx, ri) + return c +} + +func (t *svrTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, err error, panicErr interface{}) { + rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()) + if rpcStats == nil { + return + } + if panicErr != nil { + rpcStats.SetPanicked(panicErr) + } + t.opt.TracerCtl.DoFinish(ctx, ri, err) + rpcStats.Reset() +} diff --git a/pkg/streamx/client_provider.go b/pkg/streamx/client_provider.go new file mode 100644 index 0000000000..0f473e085b --- /dev/null +++ b/pkg/streamx/client_provider.go @@ -0,0 +1,29 @@ +package streamx + +import ( + "context" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +/* Hot it works + +clientProvider := xxx.NewClientProvider(xxx.WithXXX(...)) +client := {user_gencode}.NewClient({kitex_client}.WithClientProvider(clientProvider)) + => {kitex_client}.NewClient(...) + => {kitex_client}.initMiddlewares() + +stream := client.ServerStreamMethod(...) + => {kitex_client}.NewStream(...) + => {kitex_client}.internalProvider.NewStream(...) : run middlewares + => clientProvider.NewStream(...) + +res := stream.Recv(...) + => {kitex_client}.internalProvider.Stream.Recv(...) : run middlewares + => clientProvider.Stream.Recv(...) +*/ + +type ClientProvider interface { + // NewStream create a stream based on rpcinfo and callOptions + NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callOptions ...streamxcallopt.CallOption) (ClientStream, error) +} diff --git a/pkg/streamx/client_provider_internal.go b/pkg/streamx/client_provider_internal.go new file mode 100644 index 0000000000..53d9968734 --- /dev/null +++ b/pkg/streamx/client_provider_internal.go @@ -0,0 +1,24 @@ +package streamx + +import ( + "context" + + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +func NewClientProvider(cs ClientProvider) ClientProvider { + return internalClientProvider{ClientProvider: cs} +} + +type internalClientProvider struct { + ClientProvider +} + +func (p internalClientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callOptions ...streamxcallopt.CallOption) (ClientStream, error) { + cs, err := p.ClientProvider.NewStream(ctx, ri, callOptions...) + if err != nil { + return nil, err + } + return cs, nil +} diff --git a/pkg/streamx/provider/jsonrpc/client_option.go b/pkg/streamx/provider/jsonrpc/client_option.go new file mode 100644 index 0000000000..dc434cd382 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/client_option.go @@ -0,0 +1,9 @@ +package jsonrpc + +type ClientProviderOption func(cp *clientProvider) + +func WithClientPayloadLimit(limit int) ClientProviderOption { + return func(cp *clientProvider) { + cp.payloadLimit = limit + } +} diff --git a/pkg/streamx/provider/jsonrpc/client_provier.go b/pkg/streamx/provider/jsonrpc/client_provier.go new file mode 100644 index 0000000000..aa4e9fec6c --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/client_provier.go @@ -0,0 +1,47 @@ +package jsonrpc + +import ( + "context" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" + "net" +) + +var _ streamx.ClientProvider = (*clientProvider)(nil) + +func NewClientProvider(sinfo *serviceinfo.ServiceInfo, opts ...ClientProviderOption) (streamx.ClientProvider, error) { + cp := new(clientProvider) + cp.sinfo = sinfo + for _, opt := range opts { + opt(cp) + } + return cp, nil +} + +type clientProvider struct { + sinfo *serviceinfo.ServiceInfo + payloadLimit int +} + +func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callOptions ...streamxcallopt.CallOption) (streamx.ClientStream, error) { + invocation := ri.Invocation() + method := invocation.MethodName() + addr := ri.To().Address() + if addr == nil { + return nil, kerrors.ErrNoDestAddress + } + conn, err := net.Dial(addr.Network(), addr.String()) + if err != nil { + return nil, err + } + trans := newTransport(c.sinfo, conn) + s, err := trans.newStream(method) + if err != nil { + return nil, err + } + cs := newClientStream(s) + return cs, err +} diff --git a/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go b/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go new file mode 100644 index 0000000000..bf0bf56394 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go @@ -0,0 +1,161 @@ +package jsonrpc_test + +import ( + "context" + + "github.com/cloudwego/kitex/client/streamxclient" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/streamx/provider/jsonrpc" + "github.com/cloudwego/kitex/server/streamxserver" +) + +// === gen code === + +type ClientStreamingServer[Req, Res any] streamx.ClientStreamingServer[jsonrpc.Header, jsonrpc.Trailer, Req, Res] +type ServerStreamingServer[Res any] streamx.ServerStreamingServer[jsonrpc.Header, jsonrpc.Trailer, Res] +type BidiStreamingServer[Req, Res any] streamx.BidiStreamingServer[jsonrpc.Header, jsonrpc.Trailer, Req, Res] +type ClientStreamingClient[Req, Res any] streamx.ClientStreamingClient[jsonrpc.Header, jsonrpc.Trailer, Req, Res] +type ServerStreamingClient[Res any] streamx.ServerStreamingClient[jsonrpc.Header, jsonrpc.Trailer, Res] +type BidiStreamingClient[Req, Res any] streamx.BidiStreamingClient[jsonrpc.Header, jsonrpc.Trailer, Req, Res] + +var serviceInfo = &serviceinfo.ServiceInfo{ + ServiceName: "a.b.c", + Methods: map[string]serviceinfo.MethodInfo{ + "Unary": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + ctx, serviceinfo.StreamingUnary, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingUnary), + ), + "ClientStream": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + ctx, serviceinfo.StreamingClient, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingClient), + ), + "ServerStream": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + ctx, serviceinfo.StreamingServer, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingServer), + ), + "BidiStream": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + ctx, serviceinfo.StreamingBidirectional, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + ), + }, + Extra: map[string]interface{}{"streaming": true}, +} + +func NewClient(destService string, opts ...streamxclient.Option) (ClientInterface, error) { + var options []streamxclient.Option + options = append(options, streamxclient.WithDestService(destService)) + options = append(options, opts...) + cp, err := jsonrpc.NewClientProvider(serviceInfo) + if err != nil { + return nil, err + } + options = append(options, streamxclient.WithProvider(cp)) + cli, err := streamxclient.NewClient(serviceInfo, options...) + if err != nil { + return nil, err + } + kc := &kClient{Client: cli} + return kc, nil +} + +func NewServer(handler ServerInterface, opts ...streamxserver.Option) (streamxserver.Server, error) { + var options []streamxserver.Option + options = append(options, opts...) + sp, err := jsonrpc.NewServerProvider(serviceInfo) + if err != nil { + return nil, err + } + svr := streamxserver.NewServer(options...) + if err := svr.RegisterService(serviceInfo, handler, streamxserver.WithProvider(sp)); err != nil { + return nil, err + } + return svr, nil +} + +type Request struct { + Type int32 `json:"Type"` + Message string `json:"Message"` +} + +type Response struct { + Type int32 `json:"Type"` + Message string `json:"Message"` +} + +type ServerInterface interface { + Unary(ctx context.Context, req *Request) (*Response, error) + ClientStream(ctx context.Context, stream ClientStreamingServer[Request, Response]) (*Response, error) + ServerStream(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error + BidiStream(ctx context.Context, stream BidiStreamingServer[Request, Response]) error +} + +type ClientInterface interface { + Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error) + ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream ClientStreamingClient[Request, Response], err error) + ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( + stream ServerStreamingClient[Response], err error) + BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream BidiStreamingClient[Request, Response], err error) +} + +// --- Define Client Implementation --- + +var _ ClientInterface = (*kClient)(nil) + +type kClient struct { + streamxclient.Client +} + +func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (*Response, error) { + res := new(Response) + _, err := streamxclient.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + ctx, c.Client, serviceinfo.StreamingUnary, "Unary", req, res, callOptions...) + if err != nil { + return nil, err + } + return res, nil +} + +func (c *kClient) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream ClientStreamingClient[Request, Response], err error) { + return streamxclient.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + ctx, c.Client, serviceinfo.StreamingClient, "ClientStream", nil, nil, callOptions...) +} + +func (c *kClient) ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( + stream ServerStreamingClient[Response], err error) { + return streamxclient.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + ctx, c.Client, serviceinfo.StreamingServer, "ServerStream", req, nil, callOptions...) +} + +func (c *kClient) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream BidiStreamingClient[Request, Response], err error) { + return streamxclient.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + ctx, c.Client, serviceinfo.StreamingBidirectional, "BidiStream", nil, nil, callOptions...) +} diff --git a/pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go b/pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go new file mode 100644 index 0000000000..b4019c9575 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go @@ -0,0 +1,68 @@ +package jsonrpc_test + +import ( + "context" + "io" + "log" +) + +type serviceImpl struct{} + +func (si *serviceImpl) Unary(ctx context.Context, req *Request) (*Response, error) { + resp := &Response{Message: req.Message} + log.Printf("Server Unary: req={%v} resp={%v}", req, resp) + return resp, nil +} + +func (si *serviceImpl) ClientStream(ctx context.Context, stream ClientStreamingServer[Request, Response]) (res *Response, err error) { + var msg string + defer log.Printf("Server ClientStream end") + for { + req, err := stream.Recv(ctx) + if err == io.EOF { + res = new(Response) + res.Message = msg + return res, nil + } + if err != nil { + return nil, err + } + msg = req.Message + log.Printf("Server ClientStream: req={%v}", req) + } +} + +func (si *serviceImpl) ServerStream(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error { + log.Printf("Server ServerStream: req={%v}", req) + for i := 0; i < 3; i++ { + resp := new(Response) + resp.Type = int32(i) + resp.Message = req.Message + err := stream.Send(ctx, resp) + if err != nil { + return err + } + log.Printf("Server ServerStream: resp={%v}", resp) + } + return nil +} + +func (si *serviceImpl) BidiStream(ctx context.Context, stream BidiStreamingServer[Request, Response]) error { + for { + req, err := stream.Recv(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + + resp := new(Response) + resp.Message = req.Message + err = stream.Send(ctx, resp) + if err != nil { + return err + } + log.Printf("Server BidiStream: req={%v} resp={%v}", req, resp) + } +} diff --git a/pkg/streamx/provider/jsonrpc/jsonrpc_test.go b/pkg/streamx/provider/jsonrpc/jsonrpc_test.go new file mode 100644 index 0000000000..f7dc9e7f65 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/jsonrpc_test.go @@ -0,0 +1,215 @@ +package jsonrpc_test + +import ( + "context" + "errors" + "io" + "log" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/cloudwego/kitex/client/streamxclient" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/streamx/provider/jsonrpc" + "github.com/cloudwego/kitex/server/streamxserver" + "github.com/cloudwego/netpoll" +) + +func TestJSONRPC(t *testing.T) { + var addr = test.GetLocalAddress() + ln, err := netpoll.CreateListener("tcp", addr) + test.Assert(t, err == nil, err) + + // create server + var serverStreamCount int32 + waitServerStreamDone := func() { + for atomic.LoadInt32(&serverStreamCount) != 0 { + t.Logf("waitServerStreamDone: %d", atomic.LoadInt32(&serverStreamCount)) + time.Sleep(time.Millisecond * 100) + } + } + methodCount := map[string]int{} + sp, err := jsonrpc.NewServerProvider(serviceInfo) + test.Assert(t, err == nil, err) + svr := streamxserver.NewServer(streamxserver.WithListener(ln)) + err = svr.RegisterService(serviceInfo, new(serviceImpl), + streamxserver.WithProvider(sp), + streamxserver.WithStreamMiddleware( + // middleware example: server streaming mode + func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + log.Printf("Server middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", + reqArgs.Req(), resArgs.Res(), streamArgs) + test.Assert(t, streamArgs.Stream() != nil) + + switch streamArgs.Stream().Mode() { + case streamx.StreamingUnary: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() != nil) + case streamx.StreamingClient: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() != nil) + case streamx.StreamingServer: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingBidirectional: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + } + test.Assert(t, err == nil, err) + methodCount[streamArgs.Stream().Method()]++ + + log.Printf("Server middleware after next: reqArgs=%v resArgs=%v streamArgs=%v", + reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) + atomic.AddInt32(&serverStreamCount, 1) + return nil + } + }, + ), + ) + test.Assert(t, err == nil, err) + go func() { + err := svr.Run() + test.Assert(t, err == nil, err) + }() + defer svr.Stop() + time.Sleep(time.Millisecond * 100) + + // create client + ctx := context.Background() + cli, err := NewClient( + "a.b.c", + streamxclient.WithHostPorts(addr), + streamxclient.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { + return func(ctx context.Context, stream streamx.Stream, res any) (err error) { + return next(ctx, stream, res) + } + }), + streamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + log.Printf("Client middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", + reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) + err = next(ctx, streamArgs, reqArgs, resArgs) + log.Printf("Client middleware after next: reqArgs=%v resArgs=%v streamArgs=%v", + reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) + + test.Assert(t, streamArgs.Stream() != nil) + switch streamArgs.Stream().Mode() { + case streamx.StreamingUnary: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() != nil) + case streamx.StreamingClient: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingServer: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingBidirectional: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + } + test.Assert(t, err == nil, err) + return err + } + }), + ) + test.Assert(t, err == nil, err) + + t.Logf("=== Unary ===") + req := new(Request) + req.Message = "Unary" + res, err := cli.Unary(ctx, req) + test.Assert(t, err == nil, err) + test.Assert(t, req.Message == res.Message, res.Message) + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() + + // client stream + t.Logf("=== ClientStream ===") + cs, err := cli.ClientStream(ctx) + test.Assert(t, err == nil, err) + for i := 0; i < 3; i++ { + req := new(Request) + req.Type = int32(i) + req.Message = "ClientStream" + err = cs.Send(ctx, req) + test.Assert(t, err == nil, err) + } + res, err = cs.CloseAndRecv(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, res.Message == "ClientStream", res.Message) + t.Logf("Client ClientStream CloseAndRecv: %v", res) + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() + + // server stream + t.Logf("=== ServerStream ===") + req = new(Request) + req.Message = "ServerStream" + ss, err := cli.ServerStream(ctx, req) + test.Assert(t, err == nil, err) + for { + res, err := ss.Recv(ctx) + if errors.Is(err, io.EOF) { + break + } + test.Assert(t, err == nil, err) + t.Logf("Client ServerStream recv: %v", res) + } + //err = ss.CloseSend(ctx) + //test.Assert(t, err == nil, err) + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() + + // bidi stream + t.Logf("=== BidiStream ===") + bs, err := cli.BidiStream(ctx) + test.Assert(t, err == nil, err) + round := 5 + msg := "BidiStream" + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i := 0; i < round; i++ { + req := new(Request) + req.Message = msg + err := bs.Send(ctx, req) + test.Assert(t, err == nil, err) + } + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + }() + go func() { + defer wg.Done() + i := 0 + for { + res, err := bs.Recv(ctx) + if errors.Is(err, io.EOF) { + break + } + i++ + test.Assert(t, err == nil, err) + test.Assert(t, msg == res.Message, res.Message) + } + test.Assert(t, i == round, i) + }() + wg.Wait() + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() +} diff --git a/pkg/streamx/provider/jsonrpc/metadata.go b/pkg/streamx/provider/jsonrpc/metadata.go new file mode 100644 index 0000000000..595227a6c7 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/metadata.go @@ -0,0 +1,4 @@ +package jsonrpc + +type Header map[string]string +type Trailer map[string]string diff --git a/pkg/streamx/provider/jsonrpc/protocol.go b/pkg/streamx/provider/jsonrpc/protocol.go new file mode 100644 index 0000000000..a6a1d29523 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/protocol.go @@ -0,0 +1,167 @@ +package jsonrpc + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + + "github.com/cloudwego/netpoll" +) + +/* JSON RPC Protocol + +=== client create a new stream === +- send {type=META, sid=1, service="a.b.c", method="xxx"} +- send {type=DATA, sid=1, service="a.b.c", method="xxx", payload="..."} +- recv {type=DATA, sid=1, service="a.b.c", method="xxx", payload="..."} + +=== server accept a new stream === +- recv {type=META, sid=1, service="a.b.c", method="xxx"} +- recv {type=DATA, sid=1, service="a.b.c", method="xxx", payload="..."} +- send {type=DATA, sid=1, service="a.b.c", method="xxx", payload="..."} + +=== client try to close stream === +- client send {type=EOF , sid=1, service="a.b.c", method="xxx"} +- server recv {type=EOF , sid=1, service="a.b.c", method="xxx"} +- server send {type=EOF , sid=1, service="a.b.c", method="xxx"}: server close stream +- client recv {type=EOF , sid=1, service="a.b.c", method="xxx"}: client close stream + +=== server try to close stream === +- server send {type=EOF , sid=1, service="a.b.c", method="xxx"} +- client recv {type=EOF , sid=1, service="a.b.c", method="xxx"} +- client send {type=EOF , sid=1, service="a.b.c", method="xxx"}: client close stream +- server recv {type=EOF , sid=1, service="a.b.c", method="xxx"}: server close stream +*/ + +const ( + frameMagic uint32 = 0x123321 + // meta: new stream + frameTypeMeta = 0 + // data: stream streamSend/streamRecv data + frameTypeData = 1 + // eof: stream closed by peer + frameTypeEOF = 2 +) + +// Frame define a JSON RPC protocol frame +// - 4 bytes: frameMagic +// - 4 bytes: data size +// - 4 bytes: frame kind +// - 4 bytes: stream id +// - 4 bytes: service name size +// - service_name_size bytes: service name +// - 4 bytes: method name size +// - method_name_size bytes: method name +// - ... bytes: json payload +type Frame struct { + typ int + sid int + service string + method string + payload []byte +} + +func newFrame(typ int, sid int, service, method string, payload []byte) Frame { + return Frame{ + typ: typ, + sid: sid, + service: service, + method: method, + payload: payload, + } +} + +func EncodeFrame(writer io.Writer, frame Frame) (err error) { + // not include data size field length + dataSize := 4*4 + len(frame.service) + len(frame.method) + len(frame.payload) + data := make([]byte, 4+4+dataSize) + offset := 0 + + // header + binary.BigEndian.PutUint32(data[offset:offset+4], frameMagic) + offset += 4 + binary.BigEndian.PutUint32(data[offset:offset+4], uint32(dataSize)) + offset += 4 + + // data + binary.BigEndian.PutUint32(data[offset:offset+4], uint32(frame.typ)) + offset += 4 + binary.BigEndian.PutUint32(data[offset:offset+4], uint32(frame.sid)) + offset += 4 + binary.BigEndian.PutUint32(data[offset:offset+4], uint32(len(frame.service))) + offset += 4 + copy(data[offset:offset+len(frame.service)], frame.service) + offset += len(frame.service) + binary.BigEndian.PutUint32(data[offset:offset+4], uint32(len(frame.method))) + offset += 4 + copy(data[offset:offset+len(frame.method)], frame.method) + offset += len(frame.method) + copy(data[offset:offset+len(frame.payload)], frame.payload) + offset += len(frame.payload) + + idx := 0 + for idx < len(data) { + n, err := writer.Write(data[idx:]) + if err != nil { + return err + } + idx += n + } + return nil +} + +func EncodePayload(msg any) ([]byte, error) { + return json.Marshal(msg) +} + +func DecodeFrame(reader io.Reader) (frame Frame, err error) { + header := make([]byte, 8) + _, err = io.ReadFull(reader, header) + if err != nil { + return + } + magic := binary.BigEndian.Uint32(header[:4]) + size := binary.BigEndian.Uint32(header[4:8]) + if magic != frameMagic { + err = fmt.Errorf("invalid frame magic number: %d", magic) + return + } + + data := make([]byte, size) + _, err = io.ReadFull(reader, data) + if err != nil { + return + } + offset := 0 + frame.typ = int(binary.BigEndian.Uint32(data[offset : offset+4])) + offset += 4 + frame.sid = int(binary.BigEndian.Uint32(data[offset : offset+4])) + offset += 4 + serviceSize := int(binary.BigEndian.Uint32(data[offset : offset+4])) + offset += 4 + frame.service = string(data[offset : offset+serviceSize]) + offset += serviceSize + methodSize := int(binary.BigEndian.Uint32(data[offset : offset+4])) + offset += 4 + frame.method = string(data[offset : offset+methodSize]) + offset += methodSize + frame.payload = data[offset:] + return +} + +func DecodePayload(payload []byte, msg any) (err error) { + return json.Unmarshal(payload, msg) +} + +func checkFrame(conn netpoll.Connection) error { + header, err := conn.Reader().Peek(8) + if err != nil { + return err + } + magic := binary.BigEndian.Uint32(header[:4]) + if magic != frameMagic { + return fmt.Errorf("invalid frame magic number: %d", magic) + } + return nil +} diff --git a/pkg/streamx/provider/jsonrpc/server_option.go b/pkg/streamx/provider/jsonrpc/server_option.go new file mode 100644 index 0000000000..6282ebe446 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/server_option.go @@ -0,0 +1,9 @@ +package jsonrpc + +type ServerProviderOption func(pc *serverProvider) + +func WithServerPayloadLimit(limit int) ServerProviderOption { + return func(s *serverProvider) { + s.payloadLimit = limit + } +} diff --git a/pkg/streamx/provider/jsonrpc/server_provider.go b/pkg/streamx/provider/jsonrpc/server_provider.go new file mode 100644 index 0000000000..f38a7e0834 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/server_provider.go @@ -0,0 +1,61 @@ +package jsonrpc + +import ( + "context" + "net" + + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/netpoll" +) + +type serverTransCtxKey struct{} + +func NewServerProvider(sinfo *serviceinfo.ServiceInfo, opts ...ServerProviderOption) (streamx.ServerProvider, error) { + sp := new(serverProvider) + sp.sinfo = sinfo + for _, opt := range opts { + opt(sp) + } + return sp, nil +} + +var _ streamx.ServerProvider = (*serverProvider)(nil) + +type serverProvider struct { + sinfo *serviceinfo.ServiceInfo + payloadLimit int +} + +func (s serverProvider) Available(ctx context.Context, conn net.Conn) bool { + err := checkFrame(conn.(netpoll.Connection)) + return err == nil +} + +func (s serverProvider) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { + trans := newTransport(s.sinfo, conn) + return context.WithValue(ctx, serverTransCtxKey{}, trans), nil +} + +func (s serverProvider) OnInactive(ctx context.Context, conn net.Conn) (context.Context, error) { + return ctx, nil +} + +func (s serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.Context, streamx.ServerStream, error) { + trans, _ := ctx.Value(serverTransCtxKey{}).(*transport) + if trans == nil { + return nil, nil, nil + } + st, err := trans.readStream() + if err != nil { + return nil, nil, err + } + ss := newServerStream(st) + return ctx, ss, nil +} + +func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStream) (context.Context, error) { + sst := ss.(*serverStream) + err := sst.sendEOF() + return ctx, err +} diff --git a/pkg/streamx/provider/jsonrpc/stream.go b/pkg/streamx/provider/jsonrpc/stream.go new file mode 100644 index 0000000000..c9bf141905 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/stream.go @@ -0,0 +1,122 @@ +package jsonrpc + +import ( + "context" + "log" + "sync/atomic" + + "github.com/cloudwego/kitex/pkg/streamx" +) + +var ( + _ streamx.ClientStream = (*clientStream)(nil) + _ streamx.ServerStream = (*serverStream)(nil) + _ streamx.ClientStreamMetadata[Header, Trailer] = (*clientStream)(nil) + _ streamx.ServerStreamMetadata[Header, Trailer] = (*serverStream)(nil) +) + +func newStream(trans *transport, sid int, mode streamx.StreamingMode, service, method string) (s *stream) { + s = new(stream) + s.id = sid + s.mode = mode + s.service = service + s.method = method + s.trans = trans + return s +} + +type stream struct { + id int + mode streamx.StreamingMode + service string + method string + selfEOF int32 + peerEOF int32 + trans *transport +} + +func (s *stream) Header() (Header, error) { + return make(Header), nil +} + +func (s *stream) Trailer() (Trailer, error) { + return make(Trailer), nil +} + +func (s *stream) Mode() streamx.StreamingMode { + return s.mode +} + +func (s *stream) Service() string { + return s.service +} + +func (s *stream) Method() string { + return s.method +} + +func (s *stream) sendEOF() (err error) { + if !atomic.CompareAndSwapInt32(&s.selfEOF, 0, 1) { + return nil + } + log.Printf("stream[%s] send EOF", s.method) + return s.trans.streamCloseSend(s) +} + +func (s *stream) recvEOF() (err error) { + if !atomic.CompareAndSwapInt32(&s.peerEOF, 0, 1) { + return nil + } + log.Printf("stream[%s] recv EOF", s.method) + return s.trans.streamCloseRecv(s) +} + +func (s *stream) SendMsg(ctx context.Context, res any) error { + payload, err := EncodePayload(res) + if err != nil { + return err + } + return s.trans.streamSend(s, payload) +} + +func (s *stream) RecvMsg(ctx context.Context, req any) error { + payload, err := s.trans.streamRecv(s) + if err != nil { + return err + } + return DecodePayload(payload, req) +} + +func newClientStream(s *stream) *clientStream { + cs := &clientStream{stream: s} + return cs +} + +type clientStream struct { + *stream +} + +func (s *clientStream) CloseSend(ctx context.Context) error { + return s.sendEOF() +} + +func newServerStream(s *stream) streamx.ServerStream { + ss := &serverStream{stream: s} + return ss +} + +type serverStream struct { + *stream +} + +func (s *serverStream) SetHeader(hd Header) error { + return nil +} + +func (s *serverStream) SendHeader(hd Header) error { + return nil +} + +func (s *serverStream) SetTrailer(hd Trailer) error { + return nil +} diff --git a/pkg/streamx/provider/jsonrpc/transport.go b/pkg/streamx/provider/jsonrpc/transport.go new file mode 100644 index 0000000000..d36b6131c8 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/transport.go @@ -0,0 +1,153 @@ +package jsonrpc + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/netpoll" +) + +type transport struct { + sinfo *serviceinfo.ServiceInfo + conn net.Conn + streams sync.Map + sch chan *stream + rch map[int]chan Frame + wch chan Frame + stop chan struct{} +} + +func newTransport(sinfo *serviceinfo.ServiceInfo, conn net.Conn) *transport { + t := &transport{ + sinfo: sinfo, + conn: conn, + streams: sync.Map{}, + sch: make(chan *stream), + rch: map[int]chan Frame{}, + wch: make(chan Frame), + stop: make(chan struct{}), + } + go func() { + err := t.loopRead() + if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) && !errors.Is(err, netpoll.ErrConnClosed) { + klog.Debugf("transport loop read err: %v", err) + } + }() + go func() { + err := t.loopWrite() + if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { + klog.Debugf("transport loop write err: %v", err) + } + }() + return t +} + +func (t *transport) close() (err error) { + close(t.stop) + return nil +} + +func (t *transport) streamSend(s *stream, payload []byte) (err error) { + f := newFrame(frameTypeData, s.id, s.service, s.method, payload) + t.wch <- f + return nil +} + +func (t *transport) streamRecv(s *stream) (payload []byte, err error) { + f := <-t.rch[s.id] + if f.sid != s.id { // f.sid == 0 means it's a empty frame + return nil, io.EOF + } + return f.payload, nil +} + +func (t *transport) loopRead() error { + for { + // decode frame + frame, err := DecodeFrame(t.conn) + if err != nil { + return err + } + + // prepare stream + switch frame.typ { + case frameTypeMeta: // new stream + smode := t.sinfo.MethodInfo(frame.method).StreamingMode() + s := newStream(t, frame.sid, smode, frame.service, frame.method) + t.streams.Store(s.id, s) + t.rch[s.id] = make(chan Frame, 1024) + t.sch <- s + case frameTypeData, frameTypeEOF: // stream streamRecv/close + iss, ok := t.streams.Load(frame.sid) + if !ok { + return fmt.Errorf("stream not found in stream map: sid=%d", frame.sid) + } + s := iss.(*stream) + switch frame.typ { + case frameTypeEOF: + err = s.recvEOF() + return err + case frameTypeData: + // process data frame + t.rch[s.id] <- frame + } + } + } +} + +func (t *transport) loopWrite() error { + for { + select { + case <-t.stop: + return nil + case frame := <-t.wch: + err := EncodeFrame(t.conn, frame) + if err != nil { + return err + } + } + } +} + +var clientStreamID uint32 + +func (t *transport) newStream(method string) (*stream, error) { + sid := int(atomic.AddUint32(&clientStreamID, 1)) + smode := t.sinfo.MethodInfo(method).StreamingMode() + service := t.sinfo.ServiceName + f := newFrame(frameTypeMeta, sid, service, method, []byte{}) + s := newStream(t, sid, smode, service, method) + t.streams.Store(s.id, s) + t.rch[s.id] = make(chan Frame, 1024) + t.wch <- f // create stream + return s, nil +} + +func (t *transport) streamCloseRecv(s *stream) (err error) { + //for len(t.rch[s.id]) > 0 { + // runtime.Gosched() + //} + close(t.rch[s.id]) + return nil +} + +func (t *transport) streamCloseSend(s *stream) (err error) { + f := newFrame(frameTypeEOF, s.id, s.service, s.method, []byte("EOF")) + t.wch <- f + return nil +} + +func (t *transport) readStream() (*stream, error) { + select { + case <-t.stop: + return nil, io.EOF + case s := <-t.sch: + return s, nil + } +} diff --git a/pkg/streamx/provider/jsonrpc/transport_test.go b/pkg/streamx/provider/jsonrpc/transport_test.go new file mode 100644 index 0000000000..4df6e66ac4 --- /dev/null +++ b/pkg/streamx/provider/jsonrpc/transport_test.go @@ -0,0 +1,143 @@ +package jsonrpc + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +func TestCodec(t *testing.T) { + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + f1 := newFrame(0, 1, "a.b.c", "test", []byte("12345")) + err := EncodeFrame(writer, f1) + test.Assert(t, err == nil, err) + _ = writer.Flush() + reader := bufio.NewReader(&buf) + f2, err := DecodeFrame(reader) + test.Assert(t, err == nil, err) + test.Assert(t, f2.method == f1.method, f2.method) + test.Assert(t, string(f2.payload) == string(f1.payload), f2.payload) +} + +func TestTransport(t *testing.T) { + type TestRequest struct { + A int `json:"A,omitempty"` + B string `json:"B,omitempty"` + } + type TestResponse = TestRequest + method := "BidiStream" + sinfo := &serviceinfo.ServiceInfo{ + ServiceName: "a.b.c", + Methods: map[string]serviceinfo.MethodInfo{ + method: serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return nil + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + ), + }, + Extra: map[string]interface{}{"streaming": true}, + } + + var addr = test.GetLocalAddress() + ln, err := net.Listen("tcp", addr) + test.Assert(t, err == nil, err) + + // Server + var connDone int32 + var streamDone int32 + go func() { + for { + conn, err := ln.Accept() + if (conn == nil && err == nil) || errors.Is(err, net.ErrClosed) { + return + } + test.Assert(t, err == nil, err) + go func() { + defer atomic.AddInt32(&connDone, -1) + server := newTransport(sinfo, conn) + st, nerr := server.readStream() + if nerr != nil { + if nerr == io.EOF { + return + } + t.Error(nerr) + } + go func() { + defer atomic.AddInt32(&streamDone, -1) + for { + ctx := context.Background() + req := new(TestRequest) + nerr = st.RecvMsg(ctx, req) + if errors.Is(nerr, io.EOF) { + return + } + t.Logf("server recv msg: %v %v", req, nerr) + res := req + nerr = st.SendMsg(ctx, res) + t.Logf("server send msg: %v %v", res, nerr) + if nerr != nil { + if nerr == io.EOF { + return + } + t.Error(nerr) + } + } + }() + }() + } + }() + time.Sleep(time.Millisecond * 100) + + // Client + atomic.AddInt32(&connDone, 1) + conn, err := net.Dial("tcp", addr) + test.Assert(t, err == nil, err) + trans := newTransport(sinfo, conn) + s, err := trans.newStream(method) + test.Assert(t, err == nil, err) + cs := newClientStream(s) + + req := new(TestRequest) + req.A = 12345 + req.B = "hello" + res := new(TestResponse) + ctx := context.Background() + err = cs.SendMsg(ctx, req) + t.Logf("client send msg: %v", req) + test.Assert(t, err == nil, err) + err = cs.RecvMsg(ctx, res) + t.Logf("client recv msg: %v", res) + test.Assert(t, err == nil, err) + test.Assert(t, req.A == res.A, res) + test.Assert(t, req.B == res.B, res) + + // close stream + err = cs.CloseSend(ctx) + test.Assert(t, err == nil, err) + for atomic.LoadInt32(&streamDone) != 0 { + time.Sleep(time.Millisecond * 10) + } + + // close conn + err = trans.close() + test.Assert(t, err == nil, err) + err = ln.Close() + test.Assert(t, err == nil, err) + for atomic.LoadInt32(&connDone) != 0 { + time.Sleep(time.Millisecond * 10) + } +} diff --git a/pkg/streamx/provider/ttstream/client_option.go b/pkg/streamx/provider/ttstream/client_option.go new file mode 100644 index 0000000000..6d0e3a4db3 --- /dev/null +++ b/pkg/streamx/provider/ttstream/client_option.go @@ -0,0 +1,9 @@ +package ttstream + +type ClientProviderOption func(cp *clientProvider) + +func WithClientMetaHandler(metaHandler MetaFrameHandler) ClientProviderOption { + return func(cp *clientProvider) { + cp.metaHandler = metaHandler + } +} diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go new file mode 100644 index 0000000000..cf830b1248 --- /dev/null +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -0,0 +1,67 @@ +package ttstream + +import ( + "context" + "runtime" + + "github.com/bytedance/gopkg/cloud/metainfo" + "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" +) + +var _ streamx.ClientProvider = (*clientProvider)(nil) + +func NewClientProvider(sinfo *serviceinfo.ServiceInfo, opts ...ClientProviderOption) (streamx.ClientProvider, error) { + cp := new(clientProvider) + cp.sinfo = sinfo + for _, opt := range opts { + opt(cp) + } + cp.transPool = newTransPool(sinfo) + return cp, nil +} + +type clientProvider struct { + transPool *transPool + sinfo *serviceinfo.ServiceInfo + metaHandler MetaFrameHandler +} + +func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callOptions ...streamxcallopt.CallOption) (streamx.ClientStream, error) { + invocation := ri.Invocation() + method := invocation.MethodName() + addr := ri.To().Address() + if addr == nil { + return nil, kerrors.ErrNoDestAddress + } + + trans, err := c.transPool.Get(addr.Network(), addr.String()) + if err != nil { + return nil, err + } + + header := map[string]string{ + ttheader.HeaderIDLServiceName: c.sinfo.ServiceName, + } + metainfo.SaveMetaInfoToMap(ctx, header) + s, err := trans.newStream(ctx, method, header) + if err != nil { + return nil, err + } + // only client can set meta frame handler + s.setMetaFrameHandler(c.metaHandler) + cs := newClientStream(s) + runtime.SetFinalizer(cs, func(cs *clientStream) { + klog.Debugf("client stream[%v] closing", cs.sid) + _ = cs.close() + // TODO: currently using one conn one stream at same time + //_ = trans.close() + c.transPool.Put(trans) + }) + return cs, err +} diff --git a/pkg/streamx/provider/ttstream/client_trans_pool.go b/pkg/streamx/provider/ttstream/client_trans_pool.go new file mode 100644 index 0000000000..f8103e73a5 --- /dev/null +++ b/pkg/streamx/provider/ttstream/client_trans_pool.go @@ -0,0 +1,105 @@ +package ttstream + +import ( + "runtime" + "sync" + "time" + + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/netpoll" +) + +// TODO: it's to complex for users implement idle check +// so let's implement it in netpoll + +func newTransStack() *transStack { + return &transStack{} +} + +// FILO +type transStack struct { + mu sync.Mutex + stack []*transport // TODO: now it's a mem leak stack implementation + modified time.Time +} + +func (s *transStack) Pop() (trans *transport) { + s.mu.Lock() + if len(s.stack) == 0 { + s.mu.Unlock() + return nil + } + trans = s.stack[len(s.stack)-1] + s.stack = s.stack[:len(s.stack)-1] + s.mu.Unlock() + return trans +} + +func (s *transStack) Push(trans *transport) { + s.mu.Lock() + s.stack = append(s.stack, trans) + s.modified = time.Now() + s.mu.Unlock() +} + +func (s *transStack) Clear() { + s.mu.Lock() + s.stack = []*transport{} + s.modified = time.Now() + s.mu.Unlock() +} + +func newTransPool(sinfo *serviceinfo.ServiceInfo) *transPool { + tp := &transPool{sinfo: sinfo} + go func() { + now := time.Now() + deleteKeys := make([]string, 0) + tp.pool.Range(func(addr, value any) bool { + tstack := value.(*transStack) + duration := now.Sub(tstack.modified) + if duration >= time.Minute*10 { + deleteKeys = append(deleteKeys, addr.(string)) + } + return true + }) + }() + return tp +} + +type transPool struct { + pool sync.Map // {"addr":*transStack} + sinfo *serviceinfo.ServiceInfo +} + +func (c *transPool) Get(network string, addr string) (trans *transport, err error) { + var cstack *transStack + val, ok := c.pool.Load(addr) + if !ok { + // TODO: here may have a race problem + cstack = newTransStack() + _, _ = c.pool.LoadOrStore(addr, cstack) + } else { + cstack = val.(*transStack) + } + trans = cstack.Pop() + if trans != nil { + return trans, nil + } + conn, err := netpoll.DialConnection(network, addr, time.Second) + if err != nil { + return nil, err + } + trans = newTransport(clientTransport, c.sinfo, conn) + runtime.SetFinalizer(trans, func(t *transport) { t.close() }) + return trans, nil +} + +func (c *transPool) Put(trans *transport) { + var cstack *transStack + val, ok := c.pool.Load(trans.conn.RemoteAddr()) + if !ok { + return + } + cstack = val.(*transStack) + cstack.Push(trans) +} diff --git a/pkg/streamx/provider/ttstream/frame.go b/pkg/streamx/provider/ttstream/frame.go new file mode 100644 index 0000000000..41309ee945 --- /dev/null +++ b/pkg/streamx/provider/ttstream/frame.go @@ -0,0 +1,130 @@ +package ttstream + +import ( + "context" + "encoding/binary" + "fmt" + + "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/ttheader" +) + +const ( + metaFrameType int32 = 1 + headerFrameType int32 = 2 + dataFrameType int32 = 3 + trailerFrameType int32 = 4 +) + +var frameTypeToString = map[int32]string{ + metaFrameType: ttheader.FrameTypeMeta, + headerFrameType: ttheader.FrameTypeHeader, + dataFrameType: ttheader.FrameTypeData, + trailerFrameType: ttheader.FrameTypeTrailer, +} + +type Frame struct { + streamFrame + meta IntHeader + typ int32 + payload []byte +} + +func newFrame(meta streamFrame, typ int32, payload []byte) Frame { + return Frame{ + streamFrame: meta, + typ: typ, + payload: payload, + } +} + +func EncodeFrame(ctx context.Context, writer bufiox.Writer, fr Frame) (err error) { + param := ttheader.EncodeParam{ + Flags: ttheader.HeaderFlagsStreaming, + SeqID: fr.sid, + ProtocolID: ttheader.ProtocolIDThriftStruct, + } + + param.IntInfo = fr.meta + if param.IntInfo == nil { + param.IntInfo = make(IntHeader) + } + param.IntInfo[ttheader.FrameType] = frameTypeToString[fr.typ] + param.IntInfo[ttheader.ToMethod] = fr.method + + switch fr.typ { + case headerFrameType: + param.StrInfo = fr.header + case trailerFrameType: + param.StrInfo = fr.trailer + } + + totalLenField, err := ttheader.Encode(ctx, param, writer) + if err != nil { + return err + } + if len(fr.payload) > 0 { + _, err = writer.WriteBinary(fr.payload) + if err != nil { + return err + } + } + binary.BigEndian.PutUint32(totalLenField, uint32(writer.WrittenLen()-4)) + err = writer.Flush() + return err +} + +func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr Frame, err error) { + var dp ttheader.DecodeParam + dp, err = ttheader.Decode(ctx, reader) + if err != nil { + return + } + if dp.Flags != ttheader.HeaderFlagsStreaming { + err = fmt.Errorf("unexpected header flags: %d", dp.Flags) + return + } + + fr.meta = dp.IntInfo + frtype := fr.meta[ttheader.FrameType] + switch frtype { + case ttheader.FrameTypeMeta: + fr.typ = metaFrameType + case ttheader.FrameTypeHeader: + fr.typ = headerFrameType + fr.header = dp.StrInfo + case ttheader.FrameTypeData: + fr.typ = dataFrameType + case ttheader.FrameTypeTrailer: + fr.typ = trailerFrameType + fr.trailer = dp.StrInfo + default: + err = fmt.Errorf("unexpected frame type: %v", fr.meta[ttheader.FrameType]) + return + } + // stream meta + fr.sid = dp.SeqID + fr.method = fr.meta[ttheader.ToMethod] + + // frame payload + if dp.PayloadLen == 0 { + return fr, nil + } + fr.payload = make([]byte, dp.PayloadLen) + _, err = reader.ReadBinary(fr.payload) + reader.Release(err) + if err != nil { + return + } + return fr, nil +} + +func EncodePayload(ctx context.Context, msg any) ([]byte, error) { + return thrift.FastMarshal(msg.(thrift.FastCodec)), nil +} + +func DecodePayload(ctx context.Context, payload []byte, msg any) error { + err := thrift.FastUnmarshal(payload, msg.(thrift.FastCodec)) + return err +} diff --git a/pkg/streamx/provider/ttstream/frame_test.go b/pkg/streamx/provider/ttstream/frame_test.go new file mode 100644 index 0000000000..dc0d4d29e0 --- /dev/null +++ b/pkg/streamx/provider/ttstream/frame_test.go @@ -0,0 +1,51 @@ +package ttstream + +import ( + "context" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote" +) + +func TestFrameCodec(t *testing.T) { + rw := remote.NewReaderWriterBuffer(1024) + + wframe := newFrame(streamFrame{ + sid: 1, + method: "method", + header: map[string]string{"key": "value"}, + }, headerFrameType, []byte("hello world")) + err := EncodeFrame(context.Background(), rw, wframe) + test.Assert(t, err == nil, err) + + rframe, err := DecodeFrame(context.Background(), rw) + test.Assert(t, err == nil, err) + test.DeepEqual(t, string(wframe.payload), string(rframe.payload)) + test.DeepEqual(t, wframe, rframe) +} + +func TestFrameWithoutPayloadCodec(t *testing.T) { + rmsg := new(TestRequest) + rmsg.A = 1 + payload, err := EncodePayload(context.Background(), rmsg) + test.Assert(t, err == nil, err) + + wmsg := new(TestRequest) + err = DecodePayload(context.Background(), payload, wmsg) + test.Assert(t, err == nil, err) + test.DeepEqual(t, wmsg, rmsg) +} + +func TestPayloadCodec(t *testing.T) { + rmsg := new(TestRequest) + rmsg.A = 1 + rmsg.B = "hello world" + payload, err := EncodePayload(context.Background(), rmsg) + test.Assert(t, err == nil, err) + + wmsg := new(TestRequest) + err = DecodePayload(context.Background(), payload, wmsg) + test.Assert(t, err == nil, err) + test.DeepEqual(t, wmsg, rmsg) +} diff --git a/pkg/streamx/provider/ttstream/meta_frame_handler.go b/pkg/streamx/provider/ttstream/meta_frame_handler.go new file mode 100644 index 0000000000..2c193cf702 --- /dev/null +++ b/pkg/streamx/provider/ttstream/meta_frame_handler.go @@ -0,0 +1,50 @@ +package ttstream + +import "sync" + +type StreamMeta interface { + Meta() map[string]string + GetMeta(k string) (string, bool) + SetMeta(k string, v string, kvs ...string) +} + +type MetaFrameHandler interface { + OnMetaFrame(smeta StreamMeta, intHeader IntHeader, header Header, payload []byte) error +} + +var _ StreamMeta = (*streamMeta)(nil) + +func newStreamMeta() StreamMeta { + return &streamMeta{} +} + +type streamMeta struct { + sync sync.RWMutex + data map[string]string +} + +func (s *streamMeta) Meta() map[string]string { + s.sync.RLock() + m := make(map[string]string, len(s.data)) + for k, v := range s.data { + m[k] = v + } + s.sync.RUnlock() + return m +} + +func (s *streamMeta) GetMeta(k string) (string, bool) { + s.sync.RLock() + v, ok := s.data[k] + s.sync.RUnlock() + return v, ok +} + +func (s *streamMeta) SetMeta(k string, v string, kvs ...string) { + s.sync.RLock() + s.data[k] = v + for i := 0; i < len(kvs); i += 2 { + s.data[kvs[i]] = kvs[i+1] + } + s.sync.RUnlock() +} diff --git a/pkg/streamx/provider/ttstream/metadata.go b/pkg/streamx/provider/ttstream/metadata.go new file mode 100644 index 0000000000..5d9f7973d5 --- /dev/null +++ b/pkg/streamx/provider/ttstream/metadata.go @@ -0,0 +1,30 @@ +package ttstream + +import ( + "errors" + + "github.com/cloudwego/kitex/pkg/streamx" +) + +var ErrInvalidStreamKind = errors.New("invalid stream kind") + +type Header map[string]string +type Trailer map[string]string + +// only for meta frame handler +type IntHeader map[uint16]string + +// ClientStreamMeta cannot send header directly, should send from ctx +type ClientStreamMeta interface { + streamx.ClientStream + Header() (Header, error) + Trailer() (Trailer, error) +} + +// ServerStreamMeta cannot read header directly, should read from ctx +type ServerStreamMeta interface { + streamx.ServerStream + SetHeader(hd Header) error + SendHeader(hd Header) error + SetTrailer(hd Trailer) error +} diff --git a/pkg/streamx/provider/ttstream/mock_test.go b/pkg/streamx/provider/ttstream/mock_test.go new file mode 100644 index 0000000000..d289696a8b --- /dev/null +++ b/pkg/streamx/provider/ttstream/mock_test.go @@ -0,0 +1,49 @@ +package ttstream + +import ( + "encoding/json" + "fmt" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + kutils "github.com/cloudwego/kitex/pkg/utils" +) + +type TestRequest struct { + A int32 `thrift:"A,1" frugal:"1,default,i32" json:"A"` + B string `thrift:"B,2" frugal:"2,default,string" json:"B"` +} + +func (p *TestRequest) FastRead(buf []byte) (int, error) { + err := json.Unmarshal(buf, p) + if err != nil { + return 0, err + } + return len(buf), nil +} + +func (p *TestRequest) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { + data, _ := json.Marshal(p) + copy(buf, data) + return len(data) +} + +func (p *TestRequest) BLength() int { + data, _ := json.Marshal(p) + return len(data) +} + +func (p *TestRequest) DeepCopy(s interface{}) error { + src, ok := s.(*TestRequest) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + p.A = src.A + + if src.B != "" { + p.B = kutils.StringDeepCopy(src.B) + } + + return nil +} + +type TestResponse = TestRequest diff --git a/pkg/streamx/provider/ttstream/server_option.go b/pkg/streamx/provider/ttstream/server_option.go new file mode 100644 index 0000000000..5c8e1a9455 --- /dev/null +++ b/pkg/streamx/provider/ttstream/server_option.go @@ -0,0 +1,9 @@ +package ttstream + +type ServerProviderOption func(pc *serverProvider) + +func WithServerPayloadLimit(limit int) ServerProviderOption { + return func(s *serverProvider) { + s.payloadLimit = limit + } +} diff --git a/pkg/streamx/provider/ttstream/server_provider.go b/pkg/streamx/provider/ttstream/server_provider.go new file mode 100644 index 0000000000..7a8c490bf0 --- /dev/null +++ b/pkg/streamx/provider/ttstream/server_provider.go @@ -0,0 +1,81 @@ +package ttstream + +import ( + "context" + "net" + + "github.com/bytedance/gopkg/cloud/metainfo" + "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/netpoll" +) + +type serverTransCtxKey struct{} + +func NewServerProvider(sinfo *serviceinfo.ServiceInfo, opts ...ServerProviderOption) (streamx.ServerProvider, error) { + sp := new(serverProvider) + sp.sinfo = sinfo + for _, opt := range opts { + opt(sp) + } + return sp, nil +} + +var _ streamx.ServerProvider = (*serverProvider)(nil) + +type serverProvider struct { + sinfo *serviceinfo.ServiceInfo + payloadLimit int +} + +func (s serverProvider) Available(ctx context.Context, conn net.Conn) bool { + data, err := conn.(netpoll.Connection).Reader().Peek(8) + if err != nil { + return false + } + return ttheader.IsStreaming(data) +} + +func (s serverProvider) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { + trans := newTransport(serverTransport, s.sinfo, conn.(netpoll.Connection)) + return context.WithValue(ctx, serverTransCtxKey{}, trans), nil +} + +func (s serverProvider) OnInactive(ctx context.Context, conn net.Conn) (context.Context, error) { + trans, _ := ctx.Value(serverTransCtxKey{}).(*transport) + if trans == nil { + return ctx, nil + } + // server should close transport + err := trans.close() + if err != nil { + return nil, err + } + return ctx, nil +} + +func (s serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.Context, streamx.ServerStream, error) { + trans, _ := ctx.Value(serverTransCtxKey{}).(*transport) + if trans == nil { + return nil, nil, nil + } + st, err := trans.readStream() + if err != nil { + return nil, nil, err + } + ctx = metainfo.SetMetaInfoFromMap(ctx, st.header) + ss := newServerStream(st) + return ctx, ss, nil +} + +func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStream) (context.Context, error) { + sst := ss.(*serverStream) + if err := sst.sendTrailer(); err != nil { + return nil, err + } + if err := sst.close(); err != nil { + return nil, err + } + return ctx, nil +} diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go new file mode 100644 index 0000000000..458e79ee95 --- /dev/null +++ b/pkg/streamx/provider/ttstream/stream.go @@ -0,0 +1,203 @@ +package ttstream + +import ( + "context" + "errors" + "sync/atomic" + + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/streamx" +) + +var ( + _ streamx.ClientStream = (*clientStream)(nil) + _ streamx.ServerStream = (*serverStream)(nil) + _ streamx.ClientStreamMetadata[Header, Trailer] = (*clientStream)(nil) + _ streamx.ServerStreamMetadata[Header, Trailer] = (*serverStream)(nil) + _ StreamMeta = (*stream)(nil) +) + +func newStream(trans *transport, mode streamx.StreamingMode, smeta streamFrame) (s *stream) { + s = new(stream) + s.streamFrame = smeta + s.trans = trans + s.mode = mode + s.headerSig = make(chan struct{}) + s.trailerSig = make(chan struct{}) + s.StreamMeta = newStreamMeta() + trans.storeStreamIO(s) + return s +} + +type streamFrame struct { + sid int32 + method string + header Header // key:value, key is full name + trailer Trailer +} + +type stream struct { + streamFrame + trans *transport + mode streamx.StreamingMode + wheader Header + wtrailer Trailer + selfEOF int32 + peerEOF int32 + headerSig chan struct{} + trailerSig chan struct{} + + StreamMeta + metaHandler MetaFrameHandler +} + +func (s *stream) Mode() streamx.StreamingMode { + return s.mode +} + +func (s *stream) Service() string { + if len(s.header) == 0 { + return "" + } + return s.header[ttheader.HeaderIDLServiceName] +} + +func (s *stream) Method() string { + return s.method +} + +func (s *stream) setMetaFrameHandler(h MetaFrameHandler) { + s.metaHandler = h +} + +func (s *stream) readMetaFrame(intHeader IntHeader, header Header, payload []byte) (err error) { + if s.metaHandler == nil { + return nil + } + return s.metaHandler.OnMetaFrame(s.StreamMeta, intHeader, header, payload) +} + +func (s *stream) readHeader(hd Header) (err error) { + s.header = hd + select { + case <-s.headerSig: + return errors.New("already set header") + default: + close(s.headerSig) + } + klog.Debugf("stream[%s] read header: %v", s.method, hd) + return nil +} + +func (s *stream) writeHeader(hd Header) (err error) { + if s.wheader == nil { + s.wheader = make(Header) + } + for k, v := range hd { + s.wheader[k] = v + } + return nil +} + +func (s *stream) sendHeader() (err error) { + wheader := s.wheader + s.wheader = nil + err = s.trans.streamSendHeader(s.sid, s.method, wheader) + return err +} + +// readTrailer by client: unblock recv function and return EOF if no unread frame +// readTrailer by server: unblock recv function and return EOF if no unread frame +func (s *stream) readTrailer(tl Trailer) (err error) { + if !atomic.CompareAndSwapInt32(&s.peerEOF, 0, 1) { + return nil + } + + s.trailer = tl + select { + case <-s.trailerSig: + return errors.New("already set trailer") + default: + close(s.trailerSig) + } + + klog.Debugf("stream[%d] recv trailer: %v", s.sid, tl) + return s.trans.streamCloseRecv(s) +} + +func (s *stream) writeTrailer(tl Trailer) (err error) { + if s.wtrailer == nil { + s.wtrailer = make(Trailer) + } + for k, v := range tl { + s.wtrailer[k] = v + } + return nil +} + +func (s *stream) sendTrailer() (err error) { + if !atomic.CompareAndSwapInt32(&s.selfEOF, 0, 1) { + return nil + } + klog.Debugf("stream[%d] send trialer", s.sid) + return s.trans.streamSendTrailer(s.sid, s.method, s.wtrailer) +} + +func (s *stream) SendMsg(ctx context.Context, res any) error { + payload, err := EncodePayload(ctx, res) + if err != nil { + return err + } + return s.trans.streamSend(s.sid, s.method, s.wheader, payload) +} + +func (s *stream) RecvMsg(ctx context.Context, req any) error { + payload, err := s.trans.streamRecv(s.sid) + if err != nil { + return err + } + err = DecodePayload(ctx, payload, req.(thrift.FastCodec)) + return err +} + +func newClientStream(s *stream) *clientStream { + cs := &clientStream{stream: s} + return cs +} + +type clientStream struct { + *stream +} + +func (s *clientStream) CloseSend(ctx context.Context) error { + return s.sendTrailer() +} + +func (s *clientStream) close() error { + return s.trans.streamClose(s.stream) +} + +func newServerStream(s *stream) streamx.ServerStream { + ss := &serverStream{stream: s} + return ss +} + +type serverStream struct { + *stream +} + +func (s *serverStream) close() error { + return s.trans.streamClose(s.stream) +} + +// SendMsg should send left header first +func (s *serverStream) SendMsg(ctx context.Context, res any) error { + if len(s.wheader) > 0 { + if err := s.sendHeader(); err != nil { + return err + } + } + return s.stream.SendMsg(ctx, res) +} diff --git a/pkg/streamx/provider/ttstream/stream_header_trailer.go b/pkg/streamx/provider/ttstream/stream_header_trailer.go new file mode 100644 index 0000000000..5a91294dd9 --- /dev/null +++ b/pkg/streamx/provider/ttstream/stream_header_trailer.go @@ -0,0 +1,30 @@ +package ttstream + +var _ ClientStreamMeta = (*clientStream)(nil) +var _ ServerStreamMeta = (*serverStream)(nil) + +func (s *clientStream) Header() (Header, error) { + <-s.headerSig + return s.header, nil +} + +func (s *clientStream) Trailer() (Trailer, error) { + <-s.trailerSig + return s.trailer, nil +} + +func (s *serverStream) SetHeader(hd Header) error { + return s.writeHeader(hd) +} + +func (s *serverStream) SendHeader(hd Header) error { + err := s.writeHeader(hd) + if err != nil { + return err + } + return s.stream.sendHeader() +} + +func (s *serverStream) SetTrailer(tl Trailer) error { + return s.writeTrailer(tl) +} diff --git a/pkg/streamx/provider/ttstream/stream_io.go b/pkg/streamx/provider/ttstream/stream_io.go new file mode 100644 index 0000000000..4680e09850 --- /dev/null +++ b/pkg/streamx/provider/ttstream/stream_io.go @@ -0,0 +1,48 @@ +package ttstream + +import ( + "io" + "sync" +) + +type streamIO struct { + stream *stream + cond *sync.Cond + frames []Frame // TODO: using link list + end bool +} + +func newStreamIO(s *stream) *streamIO { + var lock sync.Mutex + var cond = sync.NewCond(&lock) + return &streamIO{stream: s, cond: cond} +} + +func (s *streamIO) input(f Frame) { + s.cond.L.Lock() + s.frames = append(s.frames, f) + s.cond.L.Unlock() + s.cond.Signal() +} + +func (s *streamIO) output() (f Frame, err error) { + s.cond.L.Lock() + for len(s.frames) == 0 && !s.end { + s.cond.Wait() + } + // have incoming frames or eof + if len(s.frames) == 0 && s.end { + return f, io.EOF + } + f = s.frames[0] + s.frames = s.frames[1:] + s.cond.L.Unlock() + return f, nil +} + +func (s *streamIO) eof() { + s.cond.L.Lock() + s.end = true + s.cond.L.Unlock() + s.cond.Signal() +} diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go new file mode 100644 index 0000000000..650e8d24dc --- /dev/null +++ b/pkg/streamx/provider/ttstream/transport.go @@ -0,0 +1,261 @@ +package ttstream + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/netpoll" +) + +const ( + clientTransport int32 = 1 + serverTransport int32 = 2 +) + +type transport struct { + kind int32 + sinfo *serviceinfo.ServiceInfo + conn netpoll.Connection + reader bufiox.Reader + writer bufiox.Writer + streams sync.Map // key=streamID val=streamIO + sch chan *stream // in-coming stream channel + wch chan Frame // out-coming frame channel + stop chan struct{} +} + +func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Connection) *transport { + _ = conn.SetDeadline(time.Now().Add(time.Hour)) + reader := bufiox.NewDefaultReader(conn) + writer := bufiox.NewDefaultWriter(conn) + t := &transport{ + kind: kind, + sinfo: sinfo, + conn: conn, + reader: reader, + writer: writer, + streams: sync.Map{}, + sch: make(chan *stream, 8), + wch: make(chan Frame, 8), + stop: make(chan struct{}), + } + go func() { + err := t.loopRead() + if err != nil && errors.Is(err, io.EOF) { + klog.Warnf("trans loop read err: %v", err) + } + }() + go func() { + err := t.loopWrite() + if err != nil && errors.Is(err, io.EOF) { + klog.Warnf("trans loop write err: %v", err) + return + } + }() + return t +} + +func (t *transport) storeStreamIO(s *stream) { + t.streams.Store(s.sid, newStreamIO(s)) +} + +func (t *transport) loadStreamIO(sid int32) (sio *streamIO, ok bool) { + val, ok := t.streams.Load(sid) + if !ok { + return sio, false + } + sio = val.(*streamIO) + return sio, true +} + +func (t *transport) loopRead() error { + for { + // decode frame + fr, err := DecodeFrame(context.Background(), t.reader) + if err != nil { + return err + } + + switch fr.typ { + case metaFrameType: + sio, ok := t.loadStreamIO(fr.sid) + if !ok { + return fmt.Errorf("transport[%d] read a unknown stream meta: sid=%d", t.kind, fr.sid) + } + err = sio.stream.readMetaFrame(fr.meta, fr.header, fr.payload) + if err != nil { + return err + } + case headerFrameType: + switch t.kind { + case serverTransport: + // Header Frame: server recv a new stream + smode := t.sinfo.MethodInfo(fr.method).StreamingMode() + s := newStream(t, smode, fr.streamFrame) + klog.Debugf("transport[%d] read a new stream: sid=%d", t.kind, s.sid) + t.sch <- s + case clientTransport: + // Header Frame: client recv header + sio, ok := t.loadStreamIO(fr.sid) + if !ok { + return fmt.Errorf("transport[%d] read a unknown stream header: sid=%d", t.kind, fr.sid) + } + err = sio.stream.readHeader(fr.header) + if err != nil { + return err + } + } + case dataFrameType: + // Data Frame: decode and distribute data + sio, ok := t.loadStreamIO(fr.sid) + if !ok { + return fmt.Errorf("transport[%d] read a unknown stream data: sid=%d", t.kind, fr.sid) + } + sio.input(fr) + case trailerFrameType: + // Trailer Frame: recv trailer, close read direction + sio, ok := t.loadStreamIO(fr.sid) + if !ok { + return fmt.Errorf("transport[%d] read a unknown stream trailer: sid=%d", t.kind, fr.sid) + } + if err = sio.stream.readTrailer(fr.trailer); err != nil { + return err + } + } + } +} + +func (t *transport) writeFrame(frame Frame) error { + err := EncodeFrame(context.Background(), t.writer, frame) + return err +} + +func (t *transport) loopWrite() error { + for { + select { + case <-t.stop: + // re-check wch queue + select { + case frame := <-t.wch: + if err := t.writeFrame(frame); err != nil { + return err + } + default: + return nil + } + case frame := <-t.wch: + if err := t.writeFrame(frame); err != nil { + return err + } + } + } +} + +func (t *transport) close() (err error) { + select { + case <-t.stop: + default: + klog.Warnf("transport[%s] is closing", t.conn.LocalAddr()) + close(t.stop) + t.conn.Close() + } + return nil +} + +func (t *transport) streamSend(sid int32, method string, wheader Header, payload []byte) (err error) { + if len(wheader) > 0 { + err = t.streamSendHeader(sid, method, wheader) + if err != nil { + return err + } + } + f := newFrame(streamFrame{sid: sid, method: method}, dataFrameType, payload) + t.wch <- f + return nil +} + +func (t *transport) streamSendHeader(sid int32, method string, header Header) (err error) { + f := newFrame(streamFrame{sid: sid, method: method, header: header}, headerFrameType, []byte{}) + t.wch <- f + return nil +} + +func (t *transport) streamSendTrailer(sid int32, method string, trailer Trailer) (err error) { + f := newFrame(streamFrame{sid: sid, method: method, trailer: trailer}, trailerFrameType, []byte{}) + t.wch <- f + return nil +} + +func (t *transport) streamRecv(sid int32) (payload []byte, err error) { + sio, ok := t.loadStreamIO(sid) + if !ok { + return nil, io.EOF + } + f, err := sio.output() + if err != nil { + return nil, err + } + return f.payload, nil +} + +func (t *transport) streamCloseRecv(s *stream) (err error) { + sio, ok := t.loadStreamIO(s.sid) + if !ok { + return fmt.Errorf("stream not found in stream map: sid=%d", s.sid) + } + sio.eof() + return nil +} + +func (t *transport) streamClose(s *stream) (err error) { + // remove stream from transport + t.streams.Delete(s.sid) + return nil +} + +var clientStreamID int32 + +// newStream create new stream on current connection +// it's typically used by client side +func (t *transport) newStream( + ctx context.Context, method string, header map[string]string) (*stream, error) { + if t.kind != clientTransport { + return nil, fmt.Errorf("transport already be used as other kind") + } + sid := atomic.AddInt32(&clientStreamID, 1) + smode := t.sinfo.MethodInfo(method).StreamingMode() + smeta := streamFrame{ + sid: sid, + method: method, + header: header, + } + f := newFrame(smeta, headerFrameType, []byte{}) + s := newStream(t, smode, smeta) + t.wch <- f // create stream + return s, nil +} + +// readStream wait for a new incoming stream on current connection +// it's typically used by server side +func (t *transport) readStream() (*stream, error) { + if t.kind != serverTransport { + return nil, fmt.Errorf("transport already be used as other kind") + } + select { + case <-t.stop: + return nil, io.EOF + case s := <-t.sch: + if s == nil { + return nil, io.EOF + } + return s, nil + } +} diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go new file mode 100644 index 0000000000..45c6848a45 --- /dev/null +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -0,0 +1,186 @@ +package ttstream + +import ( + "context" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/netpoll" +) + +func TestTransport(t *testing.T) { + method := "BidiStream" + sinfo := &serviceinfo.ServiceInfo{ + ServiceName: "a.b.c", + Methods: map[string]serviceinfo.MethodInfo{ + method: serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return nil + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + ), + }, + Extra: map[string]interface{}{"streaming": true}, + } + + addr := test.GetLocalAddress() + ln, err := net.Listen("tcp", addr) + test.Assert(t, err == nil, err) + defer ln.Close() + + var connDone int32 + var streamDone int32 + svr, err := netpoll.NewEventLoop(nil, + netpoll.WithOnConnect(func(ctx context.Context, connection netpoll.Connection) context.Context { + t.Logf("OnConnect started") + defer t.Logf("OnConnect finished") + trans := newTransport(serverTransport, sinfo, connection) + t.Logf("OnRead started") + defer t.Log("OnRead finished") + + go func() { + for { + s, err := trans.readStream() + t.Logf("OnRead read stream: %v, %v", s, err) + if err != nil { + if err == io.EOF { + return + } + t.Error(err) + } + ss := newServerStream(s) + go func(st streamx.ServerStream) { + defer func() { + // set trailer + err := st.(ServerStreamMeta).SetTrailer(Trailer{"key": "val"}) + test.Assert(t, err == nil, err) + + // send trailer + err = ss.(*serverStream).sendTrailer() + test.Assert(t, err == nil, err) + atomic.AddInt32(&streamDone, -1) + }() + + // send header + err := st.(ServerStreamMeta).SendHeader(Header{"key": "val"}) + test.Assert(t, err == nil, err) + + // send data + for { + req := new(TestRequest) + err := st.RecvMsg(ctx, req) + if errors.Is(err, io.EOF) { + t.Logf("server stream eof") + return + } + test.Assert(t, err == nil, err) + t.Logf("server recv msg: %v", req) + + res := req + err = st.SendMsg(ctx, res) + if errors.Is(err, io.EOF) { + return + } + test.Assert(t, err == nil, err) + t.Logf("server send msg: %v", res) + } + }(ss) + } + }() + + return context.WithValue(ctx, "trans", trans) + }), netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) { + t.Logf("OnDisconnect started") + defer t.Logf("OnDisconnect finished") + + atomic.AddInt32(&connDone, -1) + })) + go func() { + err = svr.Serve(ln) + test.Assert(t, err == nil, err) + }() + defer svr.Shutdown(context.Background()) + test.WaitServerStart(addr) + + // Client + ctx := context.Background() + atomic.AddInt32(&connDone, 1) + conn, err := netpoll.DialConnection("tcp", addr, time.Second) + test.Assert(t, err == nil, err) + trans := newTransport(clientTransport, sinfo, conn) + + var wg sync.WaitGroup + for sid := 1; sid <= 1; sid++ { + wg.Add(1) + atomic.AddInt32(&streamDone, 1) + go func(sid int) { + defer wg.Done() + + // send header + s, err := trans.newStream(ctx, method, map[string]string{}) + test.Assert(t, err == nil, err) + + cs := newClientStream(s) + t.Logf("client stream[%d] created", sid) + + // recv header + hd, err := cs.Header() + test.Assert(t, err == nil, err) + test.Assert(t, hd["key"] == "val", hd) + t.Logf("client stream[%d] recv header=%v", sid, hd) + + // send and recv data + for i := 0; i < 3; i++ { + req := new(TestRequest) + req.A = 12345 + req.B = "hello" + res := new(TestResponse) + err = cs.SendMsg(ctx, req) + t.Logf("client stream[%d] send msg: %v", sid, req) + test.Assert(t, err == nil, err) + err = cs.RecvMsg(ctx, res) + t.Logf("client stream[%d] recv msg: %v", sid, res) + test.Assert(t, err == nil, err) + test.Assert(t, req.A == res.A, res) + test.Assert(t, req.B == res.B, res) + } + + // send trailer(trailer is stored in ctx) + err = cs.CloseSend(ctx) + test.Assert(t, err == nil, err) + t.Logf("client stream[%d] close send", sid) + + // recv trailer + tl, err := cs.Trailer() + test.Assert(t, err == nil, err) + test.Assert(t, tl["key"] == "val", tl) + t.Logf("client stream[%d] recv trailer=%v", sid, tl) + }(sid) + } + wg.Wait() + for atomic.LoadInt32(&streamDone) != 0 { + t.Logf("wait all streams closed") + time.Sleep(time.Millisecond * 10) + } + + // close conn + err = trans.close() + test.Assert(t, err == nil, err) + err = ln.Close() + test.Assert(t, err == nil, err) + for atomic.LoadInt32(&connDone) != 0 { + time.Sleep(time.Millisecond * 10) + t.Logf("wait all connections closed") + } +} diff --git a/pkg/streamx/provider/ttstream/ttstream_client_test.go b/pkg/streamx/provider/ttstream/ttstream_client_test.go new file mode 100644 index 0000000000..2aedc8f488 --- /dev/null +++ b/pkg/streamx/provider/ttstream/ttstream_client_test.go @@ -0,0 +1,322 @@ +package ttstream_test + +import ( + "context" + "errors" + "io" + "log" + "net/http" + _ "net/http/pprof" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/client/streamxclient" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" + "github.com/cloudwego/kitex/server" + "github.com/cloudwego/kitex/server/streamxserver" + "github.com/cloudwego/kitex/transport" + "github.com/cloudwego/netpoll" +) + +func TestTTHeaderStreaming(t *testing.T) { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + + var addr = test.GetLocalAddress() + ln, err := netpoll.CreateListener("tcp", addr) + test.Assert(t, err == nil, err) + defer ln.Close() + + // create server + var serverStreamCount int32 + waitServerStreamDone := func() { + for atomic.LoadInt32(&serverStreamCount) != 0 { + t.Logf("waitServerStreamDone: %d", atomic.LoadInt32(&serverStreamCount)) + time.Sleep(time.Millisecond * 100) + } + } + methodCount := map[string]int{} + serverRecvCount := map[string]int{} + serverSendCount := map[string]int{} + svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) + // register pingpong service + err = svr.RegisterService(pingpongServiceInfo, new(pingpongService)) + test.Assert(t, err == nil, err) + // register streamingService as ttstreaam provider + sp, err := ttstream.NewServerProvider(streamingServiceInfo) + test.Assert(t, err == nil, err) + err = svr.RegisterService( + streamingServiceInfo, + new(streamingService), + streamxserver.WithProvider(sp), + streamxserver.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { + return func(ctx context.Context, stream streamx.Stream, res any) (err error) { + err = next(ctx, stream, res) + if err == nil { + serverRecvCount[stream.Method()]++ + } else { + log.Printf("server recv middleware err=%v", err) + } + return err + } + }), + streamxserver.WithStreamSendMiddleware(func(next streamx.StreamSendEndpoint) streamx.StreamSendEndpoint { + return func(ctx context.Context, stream streamx.Stream, req any) (err error) { + err = next(ctx, stream, req) + if err == nil { + serverSendCount[stream.Method()]++ + } else { + log.Printf("server send middleware err=%v", err) + } + return err + } + }), + streamxserver.WithStreamMiddleware( + // middleware example: server streaming mode + func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + log.Printf("Server middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", + reqArgs.Req(), resArgs.Res(), streamArgs) + test.Assert(t, streamArgs.Stream() != nil) + test.Assert(t, ValidateMetadata(ctx)) + + log.Printf("Server handler start") + switch streamArgs.Stream().Mode() { + case streamx.StreamingUnary: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() != nil) + case streamx.StreamingClient: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() != nil) + case streamx.StreamingServer: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingBidirectional: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + } + test.Assert(t, err == nil, err) + methodCount[streamArgs.Stream().Method()]++ + log.Printf("Server handler end") + + log.Printf("Server middleware after next: reqArgs=%v resArgs=%v streamArgs=%v", + reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) + atomic.AddInt32(&serverStreamCount, 1) + return nil + } + }, + ), + ) + test.Assert(t, err == nil, err) + go func() { + err := svr.Run() + test.Assert(t, err == nil, err) + }() + defer svr.Stop() + test.WaitServerStart(addr) + + // create client + pingpongClient, err := NewPingPongClient( + "kitex.service.pingpong", + client.WithHostPorts(addr), + client.WithTransportProtocol(transport.TTHeaderFramed), + client.WithPayloadCodec(thrift.NewThriftCodecWithConfig(thrift.FastRead|thrift.FastWrite|thrift.EnableSkipDecoder)), + ) + test.Assert(t, err == nil, err) + streamClient, err := NewStreamingClient( + "kitex.service.streaming", + streamxclient.WithHostPorts(addr), + streamxclient.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { + return func(ctx context.Context, stream streamx.Stream, res any) (err error) { + err = next(ctx, stream, res) + log.Printf("Client recv middleware %v", res) + return err + } + }), + streamxclient.WithStreamSendMiddleware(func(next streamx.StreamSendEndpoint) streamx.StreamSendEndpoint { + return func(ctx context.Context, stream streamx.Stream, req any) (err error) { + err = next(ctx, stream, req) + log.Printf("Client send middleware %v", req) + return err + } + }), + streamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + // validate ctx + test.Assert(t, ValidateMetadata(ctx)) + + log.Printf("Client middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", + reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, err == nil, err) + log.Printf("Client middleware after next: reqArgs=%v resArgs=%v streamArgs=%v", + reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) + + test.Assert(t, streamArgs.Stream() != nil) + switch streamArgs.Stream().Mode() { + case streamx.StreamingUnary: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() != nil) + case streamx.StreamingClient: + test.Assert(t, reqArgs.Req() == nil, reqArgs.Req()) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingServer: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingBidirectional: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + } + return err + } + }), + ) + test.Assert(t, err == nil, err) + + // prepare metainfo + ctx := context.Background() + ctx = SetMetadata(ctx) + + t.Logf("=== PingPong ===") + req := new(Request) + req.Message = "PingPong" + res, err := pingpongClient.PingPong(ctx, req) + test.Assert(t, err == nil, err) + test.Assert(t, req.Message == res.Message, res) + + t.Logf("=== Unary ===") + req = new(Request) + req.Type = 10000 + req.Message = "Unary" + res, err = streamClient.Unary(ctx, req) + test.Assert(t, err == nil, err) + test.Assert(t, req.Type == res.Type, res.Type) + test.Assert(t, req.Message == res.Message, res.Message) + test.Assert(t, serverRecvCount["Unary"] == 1, serverRecvCount) + test.Assert(t, serverSendCount["Unary"] == 1, serverSendCount) + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() + + // client stream + round := 5 + t.Logf("=== ClientStream ===") + cs, err := streamClient.ClientStream(ctx) + test.Assert(t, err == nil, err) + for i := 0; i < round; i++ { + req := new(Request) + req.Type = int32(i) + req.Message = "ClientStream" + err = cs.Send(ctx, req) + test.Assert(t, err == nil, err) + } + res, err = cs.CloseAndRecv(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, res.Message == "ClientStream", res.Message) + t.Logf("Client ClientStream CloseAndRecv: %v", res) + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() + test.Assert(t, serverRecvCount["ClientStream"] == round, serverRecvCount) + test.Assert(t, serverSendCount["ClientStream"] == 1, serverSendCount) + cs = nil + runtime.GC() + + // server stream + t.Logf("=== ServerStream ===") + req = new(Request) + req.Message = "ServerStream" + ss, err := streamClient.ServerStream(ctx, req) + test.Assert(t, err == nil, err) + // server stream recv header + hd, err := ss.Header() + test.Assert(t, err == nil, err) + t.Logf("Client ServerStream recv header: %v", hd) + test.DeepEqual(t, hd["key1"], "val1") + test.DeepEqual(t, hd["key2"], "val2") + received := 0 + for { + res, err := ss.Recv(ctx) + if errors.Is(err, io.EOF) { + break + } + test.Assert(t, err == nil, err) + received++ + t.Logf("Client ServerStream recv: %v", res) + } + err = ss.CloseSend(ctx) + test.Assert(t, err == nil, err) + // server stream recv trailer + tl, err := ss.Trailer() + test.Assert(t, err == nil, err) + t.Logf("Client ServerStream recv trailer: %v", tl) + test.DeepEqual(t, tl["key1"], "val1") + test.DeepEqual(t, tl["key2"], "val2") + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() + test.Assert(t, serverRecvCount["ServerStream"] == 1, serverRecvCount) + test.Assert(t, serverSendCount["ServerStream"] == received, serverSendCount) + ss = nil + runtime.GC() + + // bidi stream + round = 5 + t.Logf("=== BidiStream ===") + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + msg := "BidiStream" + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i := 0; i < round; i++ { + req := new(Request) + req.Message = msg + err := bs.Send(ctx, req) + test.Assert(t, err == nil, err) + } + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + }() + go func() { + defer wg.Done() + i := 0 + for { + res, err := bs.Recv(ctx) + if errors.Is(err, io.EOF) { + break + } + i++ + test.Assert(t, err == nil, err) + test.Assert(t, msg == res.Message, res.Message) + } + test.Assert(t, i == round, i) + }() + wg.Wait() + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() + test.Assert(t, serverRecvCount["BidiStream"] == round, serverRecvCount) + test.Assert(t, serverSendCount["BidiStream"] == round, serverSendCount) + bs = nil + runtime.GC() + + streamClient = nil +} diff --git a/pkg/streamx/provider/ttstream/ttstream_common_test.go b/pkg/streamx/provider/ttstream/ttstream_common_test.go new file mode 100644 index 0000000000..9d424c1826 --- /dev/null +++ b/pkg/streamx/provider/ttstream/ttstream_common_test.go @@ -0,0 +1,44 @@ +package ttstream_test + +import ( + "context" + "github.com/bytedance/gopkg/cloud/metainfo" +) + +var persistKVs = map[string]string{ + "p1": "v1", + "p2": "v2", + "p3": "v3", +} + +var transitKVs = map[string]string{ + "t1": "v1", + "t2": "v2", + "t3": "v3", +} + +func SetMetadata(ctx context.Context) context.Context { + for k, v := range persistKVs { + ctx = metainfo.WithPersistentValue(ctx, k, v) + } + for k, v := range transitKVs { + ctx = metainfo.WithValue(ctx, k, v) + } + return ctx +} + +func ValidateMetadata(ctx context.Context) bool { + for k, v := range persistKVs { + _v, _ := metainfo.GetPersistentValue(ctx, k) + if _v != v { + return false + } + } + for k, v := range transitKVs { + _v, _ := metainfo.GetValue(ctx, k) + if _v != v { + return false + } + } + return true +} diff --git a/pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go b/pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go new file mode 100644 index 0000000000..0563679310 --- /dev/null +++ b/pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go @@ -0,0 +1,451 @@ +// Code generated by Kitex v1.16.4. DO NOT EDIT. + +package ttstream_test + +import ( + "bytes" + "fmt" + "reflect" + "strings" + + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + + "github.com/apache/thrift/lib/go/thrift" + kutils "github.com/cloudwego/kitex/pkg/utils" +) + +// unused protection +var ( + _ = fmt.Formatter(nil) + _ = (*bytes.Buffer)(nil) + _ = (*strings.Builder)(nil) + _ = reflect.Type(nil) + _ = thrift.TProtocol(nil) + _ = bthrift.BinaryWriter(nil) +) + +var fieldIDToName_Request = map[int16]string{ + 1: "Type", + 2: "Message", +} + +var fieldIDToName_Response = map[int16]string{ + 1: "Type", + 2: "Message", +} + +type Request struct { + Type int32 `thrift:"Type,1" frugal:"1,default,i32" json:"Type"` + Message string `thrift:"Message,2" frugal:"2,default,string" json:"Message"` +} + +type Response struct { + Type int32 `thrift:"Type,1" frugal:"1,default,i32" json:"Type"` + Message string `thrift:"Message,2" frugal:"2,default,string" json:"Message"` +} + +type ServerPingPongArgs struct { + Req *Request `thrift:"req,1" frugal:"1,default,Request" json:"req"` +} + +type ServerPingPongResult struct { + Success *Response `thrift:"success,0,optional" frugal:"0,optional,Response" json:"success,omitempty"` +} + +func NewServerPingPongArgs() *ServerPingPongArgs { + return &ServerPingPongArgs{} +} + +func NewServerPingPongResult() *ServerPingPongResult { + return &ServerPingPongResult{} +} + +func (p ServerPingPongResult) GetSuccess() *Response { + return p.Success +} + +func (p *Request) FastRead(buf []byte) (int, error) { + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + _, l, err = bthrift.Binary.ReadStructBegin(buf) + offset += l + if err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.I32 { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + + l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldEndError + } + } + l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) + offset += l + if err != nil { + goto ReadStructEndError + } + + return offset, nil +ReadStructBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Request[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +ReadFieldEndError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *Request) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field int32 + if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + + _field = v + + } + p.Type = _field + return offset, nil +} + +func (p *Request) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field string + if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + + _field = v + + } + p.Message = _field + return offset, nil +} + +// for compatibility +func (p *Request) FastWrite(buf []byte) int { + return 0 +} + +func (p *Request) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { + offset := 0 + offset += bthrift.Binary.WriteStructBegin(buf[offset:], "Request") + if p != nil { + offset += p.fastWriteField1(buf[offset:], binaryWriter) + offset += p.fastWriteField2(buf[offset:], binaryWriter) + } + offset += bthrift.Binary.WriteFieldStop(buf[offset:]) + offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + return offset +} + +func (p *Request) BLength() int { + l := 0 + l += bthrift.Binary.StructBeginLength("Request") + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += bthrift.Binary.FieldStopLength() + l += bthrift.Binary.StructEndLength() + return l +} + +func (p *Request) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { + offset := 0 + offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Type", thrift.I32, 1) + offset += bthrift.Binary.WriteI32(buf[offset:], p.Type) + offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + return offset +} + +func (p *Request) fastWriteField2(buf []byte, binaryWriter bthrift.BinaryWriter) int { + offset := 0 + offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Message", thrift.STRING, 2) + offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Message) + offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + return offset +} + +func (p *Request) field1Length() int { + l := 0 + l += bthrift.Binary.FieldBeginLength("Type", thrift.I32, 1) + l += bthrift.Binary.I32Length(p.Type) + l += bthrift.Binary.FieldEndLength() + return l +} + +func (p *Request) field2Length() int { + l := 0 + l += bthrift.Binary.FieldBeginLength("Message", thrift.STRING, 2) + l += bthrift.Binary.StringLengthNocopy(p.Message) + l += bthrift.Binary.FieldEndLength() + return l +} + +func (p *Request) DeepCopy(s interface{}) error { + src, ok := s.(*Request) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + p.Type = src.Type + + if src.Message != "" { + p.Message = kutils.StringDeepCopy(src.Message) + } + + return nil +} + +func (p *Response) FastRead(buf []byte) (int, error) { + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + _, l, err = bthrift.Binary.ReadStructBegin(buf) + offset += l + if err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.I32 { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + + l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldEndError + } + } + l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) + offset += l + if err != nil { + goto ReadStructEndError + } + + return offset, nil +ReadStructBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Response[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +ReadFieldEndError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *Response) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field int32 + if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + + _field = v + + } + p.Type = _field + return offset, nil +} + +func (p *Response) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field string + if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + + _field = v + + } + p.Message = _field + return offset, nil +} + +// for compatibility +func (p *Response) FastWrite(buf []byte) int { + return 0 +} + +func (p *Response) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { + offset := 0 + offset += bthrift.Binary.WriteStructBegin(buf[offset:], "Response") + if p != nil { + offset += p.fastWriteField1(buf[offset:], binaryWriter) + offset += p.fastWriteField2(buf[offset:], binaryWriter) + } + offset += bthrift.Binary.WriteFieldStop(buf[offset:]) + offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + return offset +} + +func (p *Response) BLength() int { + l := 0 + l += bthrift.Binary.StructBeginLength("Response") + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += bthrift.Binary.FieldStopLength() + l += bthrift.Binary.StructEndLength() + return l +} + +func (p *Response) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { + offset := 0 + offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Type", thrift.I32, 1) + offset += bthrift.Binary.WriteI32(buf[offset:], p.Type) + offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + return offset +} + +func (p *Response) fastWriteField2(buf []byte, binaryWriter bthrift.BinaryWriter) int { + offset := 0 + offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Message", thrift.STRING, 2) + offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Message) + offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + return offset +} + +func (p *Response) field1Length() int { + l := 0 + l += bthrift.Binary.FieldBeginLength("Type", thrift.I32, 1) + l += bthrift.Binary.I32Length(p.Type) + l += bthrift.Binary.FieldEndLength() + return l +} + +func (p *Response) field2Length() int { + l := 0 + l += bthrift.Binary.FieldBeginLength("Message", thrift.STRING, 2) + l += bthrift.Binary.StringLengthNocopy(p.Message) + l += bthrift.Binary.FieldEndLength() + return l +} + +func (p *Response) DeepCopy(s interface{}) error { + src, ok := s.(*Response) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + p.Type = src.Type + + if src.Message != "" { + p.Message = kutils.StringDeepCopy(src.Message) + } + + return nil +} diff --git a/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go b/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go new file mode 100644 index 0000000000..c47a464620 --- /dev/null +++ b/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go @@ -0,0 +1,205 @@ +package ttstream_test + +import ( + "context" + + "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/client/streamxclient" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" + "github.com/cloudwego/kitex/server" + "github.com/cloudwego/kitex/server/streamxserver" +) + +// === gen code === + +// --- Define Header and Trailer type --- +type ClientStreamingServer[Req, Res any] streamx.ClientStreamingServer[ttstream.Header, ttstream.Trailer, Req, Res] +type ServerStreamingServer[Res any] streamx.ServerStreamingServer[ttstream.Header, ttstream.Trailer, Res] +type BidiStreamingServer[Req, Res any] streamx.BidiStreamingServer[ttstream.Header, ttstream.Trailer, Req, Res] + +type ClientStreamingClient[Req, Res any] streamx.ClientStreamingClient[ttstream.Header, ttstream.Trailer, Req, Res] +type ServerStreamingClient[Res any] streamx.ServerStreamingClient[ttstream.Header, ttstream.Trailer, Res] +type BidiStreamingClient[Req, Res any] streamx.BidiStreamingClient[ttstream.Header, ttstream.Trailer, Req, Res] + +// --- Define Service Method handler --- +var pingpongServiceInfo = &serviceinfo.ServiceInfo{ + ServiceName: "kitex.service.pingpong", + PayloadCodec: serviceinfo.Thrift, + HandlerType: (*PingPongServerInterface)(nil), + Methods: map[string]serviceinfo.MethodInfo{ + "PingPong": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + realArg := reqArgs.(*ServerPingPongArgs) + realResult := resArgs.(*ServerPingPongResult) + success, err := handler.(PingPongServerInterface).PingPong(ctx, realArg.Req) + if err != nil { + return err + } + realResult.Success = success + return nil + }, + func() interface{} { return NewServerPingPongArgs() }, + func() interface{} { return NewServerPingPongResult() }, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingNone), + ), + }, + Extra: map[string]interface{}{"streaming": false}, +} + +var streamingServiceInfo = &serviceinfo.ServiceInfo{ + ServiceName: "kitex.service.streaming", + Methods: map[string]serviceinfo.MethodInfo{ + "Unary": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + ctx, serviceinfo.StreamingUnary, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingUnary), + ), + "ClientStream": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + ctx, serviceinfo.StreamingClient, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingClient), + ), + "ServerStream": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + ctx, serviceinfo.StreamingServer, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingServer), + ), + "BidiStream": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + ctx, serviceinfo.StreamingBidirectional, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + ), + }, + Extra: map[string]interface{}{"streaming": true}, +} + +// --- Define RegisterService interface --- +func RegisterService(svr server.Server, handler StreamingServerInterface, opts ...server.RegisterOption) error { + return svr.RegisterService(streamingServiceInfo, handler, opts...) +} + +// --- Define New Client interface --- +func NewPingPongClient(destService string, opts ...client.Option) (PingPongClientInterface, error) { + var options []client.Option + options = append(options, client.WithDestService(destService)) + options = append(options, opts...) + cli, err := client.NewClient(pingpongServiceInfo, options...) + if err != nil { + return nil, err + } + kc := &kClient{caller: cli} + return kc, nil +} + +func NewStreamingClient(destService string, opts ...streamxclient.Option) (StreamingClientInterface, error) { + var options []streamxclient.Option + options = append(options, streamxclient.WithDestService(destService)) + options = append(options, opts...) + cp, err := ttstream.NewClientProvider(streamingServiceInfo) + if err != nil { + return nil, err + } + options = append(options, streamxclient.WithProvider(cp)) + cli, err := streamxclient.NewClient(streamingServiceInfo, options...) + if err != nil { + return nil, err + } + kc := &kClient{streamer: cli, caller: cli.(client.Client)} + return kc, nil +} + +// --- Define Server Implementation Interface --- +type PingPongServerInterface interface { + PingPong(ctx context.Context, req *Request) (*Response, error) +} +type StreamingServerInterface interface { + Unary(ctx context.Context, req *Request) (*Response, error) + ClientStream(ctx context.Context, stream ClientStreamingServer[Request, Response]) (*Response, error) + ServerStream(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error + BidiStream(ctx context.Context, stream BidiStreamingServer[Request, Response]) error +} + +// --- Define Client Implementation Interface --- +type PingPongClientInterface interface { + PingPong(ctx context.Context, req *Request) (r *Response, err error) +} + +type StreamingClientInterface interface { + Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error) + ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream ClientStreamingClient[Request, Response], err error) + ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( + stream ServerStreamingClient[Response], err error) + BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream BidiStreamingClient[Request, Response], err error) +} + +// --- Define Client Implementation --- +var _ StreamingClientInterface = (*kClient)(nil) +var _ PingPongClientInterface = (*kClient)(nil) + +type kClient struct { + caller client.Client + streamer streamxclient.Client +} + +func (c *kClient) PingPong(ctx context.Context, req *Request) (r *Response, err error) { + var _args ServerPingPongArgs + _args.Req = req + var _result ServerPingPongResult + if err = c.caller.Call(ctx, "PingPong", &_args, &_result); err != nil { + return + } + return _result.GetSuccess(), nil +} + +func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (*Response, error) { + res := new(Response) + _, err := streamxclient.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + ctx, c.streamer, serviceinfo.StreamingUnary, "Unary", req, res, callOptions...) + if err != nil { + return nil, err + } + return res, nil +} + +func (c *kClient) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream ClientStreamingClient[Request, Response], err error) { + return streamxclient.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + ctx, c.streamer, serviceinfo.StreamingClient, "ClientStream", nil, nil, callOptions...) +} + +func (c *kClient) ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( + stream ServerStreamingClient[Response], err error) { + return streamxclient.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + ctx, c.streamer, serviceinfo.StreamingServer, "ServerStream", req, nil, callOptions...) +} + +func (c *kClient) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream BidiStreamingClient[Request, Response], err error) { + return streamxclient.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + ctx, c.streamer, serviceinfo.StreamingBidirectional, "BidiStream", nil, nil, callOptions...) +} diff --git a/pkg/streamx/provider/ttstream/ttstream_server_test.go b/pkg/streamx/provider/ttstream/ttstream_server_test.go new file mode 100644 index 0000000000..91e1559108 --- /dev/null +++ b/pkg/streamx/provider/ttstream/ttstream_server_test.go @@ -0,0 +1,81 @@ +package ttstream_test + +import ( + "context" + "io" + "log" +) + +type pingpongService struct{} +type streamingService struct{} + +func (si *pingpongService) PingPong(ctx context.Context, req *Request) (*Response, error) { + resp := &Response{Type: req.Type, Message: req.Message} + log.Printf("Server PingPong: req={%v} resp={%v}", req, resp) + return resp, nil +} + +func (si *streamingService) Unary(ctx context.Context, req *Request) (*Response, error) { + resp := &Response{Type: req.Type, Message: req.Message} + log.Printf("Server Unary: req={%v} resp={%v}", req, resp) + return resp, nil +} + +func (si *streamingService) ClientStream(ctx context.Context, stream ClientStreamingServer[Request, Response]) (res *Response, err error) { + var msg string + defer log.Printf("Server ClientStream end") + for { + req, err := stream.Recv(ctx) + if err == io.EOF { + res = new(Response) + res.Message = msg + return res, nil + } + if err != nil { + return nil, err + } + msg = req.Message + log.Printf("Server ClientStream: req={%v}", req) + } +} + +func (si *streamingService) ServerStream(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error { + log.Printf("Server ServerStream: req={%v}", req) + + _ = stream.SetHeader(map[string]string{"key1": "val1"}) + _ = stream.SendHeader(map[string]string{"key2": "val2"}) + _ = stream.SetTrailer(map[string]string{"key1": "val1"}) + _ = stream.SetTrailer(map[string]string{"key2": "val2"}) + + for i := 0; i < 3; i++ { + resp := new(Response) + resp.Type = int32(i) + resp.Message = req.Message + err := stream.Send(ctx, resp) + if err != nil { + return err + } + log.Printf("Server ServerStream: send resp={%v}", resp) + } + return nil +} + +func (si *streamingService) BidiStream(ctx context.Context, stream BidiStreamingServer[Request, Response]) error { + for { + req, err := stream.Recv(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + + resp := new(Response) + resp.Message = req.Message + err = stream.Send(ctx, resp) + if err != nil { + return err + } + log.Printf("Server BidiStream: req={%v} resp={%v}", req, resp) + } +} diff --git a/pkg/streamx/server_provider.go b/pkg/streamx/server_provider.go new file mode 100644 index 0000000000..33c551aaf7 --- /dev/null +++ b/pkg/streamx/server_provider.go @@ -0,0 +1,51 @@ +package streamx + +import ( + "context" + "net" +) + +/* Hot it works + +serverProvider := xxx.NewServerProvider(xxx.WithXXX()...) +server := {user_gencode}.NewServer({kitex_server}.WithServerProvider(serverProvider)) + => {kitex_server}.NewServer(WithTransHandlerFactory({kitex_server}.NewSvrTransHandlerFactory()) + => {kitex_server}.initMiddlewares() + => {kitex_server}.initServerTransHandler() + + +{kitex_server}.ServerTransHandler.OnRead +=> serverProvider.Available(conn) +=> serverProvider.OnActive(conn) +=> stream := serverProvider.OnStream(conn) + => {kitex_server}.internalProvider.OnStream(conn) + => serverProvider.OnStream(conn) + +res := stream.Recv(...) + => {kitex_server}.internalProvider.Stream.Recv(...) : run middlewares + => serverProvider.Stream.Recv(...) + +stream.Close() - server handler return +*/ + +/* Hot it works +- NewServer 时,初始化 ServerProvider,并注册 streamx.ServerTransHandler +- 连接进来的时候,detection trans handler 会转发调用 streamx.ServerTransHandler +- streamx.ServerTransHandler 负责调用 ServerProvider 的相关方法 + +后续读写都发生在 stream 提供的方法中 + + +*/ + +type ServerProvider interface { + // Available detect if provider can process conn from its first N bytes + Available(ctx context.Context, conn net.Conn) bool // ProtocolMath + // OnActive called when conn connected + OnActive(ctx context.Context, conn net.Conn) (context.Context, error) + OnInactive(ctx context.Context, conn net.Conn) (context.Context, error) + // OnStream should read conn data and return a server stream + OnStream(ctx context.Context, conn net.Conn) (context.Context, ServerStream, error) + // OnStreamFinish should be called when user server handler returned, typically provide should close the stream + OnStreamFinish(ctx context.Context, ss ServerStream) (context.Context, error) +} diff --git a/pkg/streamx/server_provider_internal.go b/pkg/streamx/server_provider_internal.go new file mode 100644 index 0000000000..2342aecbc0 --- /dev/null +++ b/pkg/streamx/server_provider_internal.go @@ -0,0 +1,25 @@ +package streamx + +import ( + "context" + "net" +) + +func NewServerProvider(ss ServerProvider) ServerProvider { + if _, ok := ss.(*internalServerProvider); ok { + return ss + } + return internalServerProvider{ServerProvider: ss} +} + +type internalServerProvider struct { + ServerProvider +} + +func (p internalServerProvider) OnStream(ctx context.Context, conn net.Conn) (context.Context, ServerStream, error) { + ctx, ss, err := p.ServerProvider.OnStream(ctx, conn) + if err != nil { + return nil, nil, err + } + return ctx, ss, nil +} diff --git a/pkg/streamx/stream.go b/pkg/streamx/stream.go new file mode 100644 index 0000000000..b9abdd0bff --- /dev/null +++ b/pkg/streamx/stream.go @@ -0,0 +1,272 @@ +package streamx + +import ( + "context" + + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +var _ ServerStreamingClient[int, int, int] = (*GenericClientStream[int, int, int, int])(nil) +var _ ClientStreamingClient[int, int, int, int] = (*GenericClientStream[int, int, int, int])(nil) +var _ BidiStreamingClient[int, int, int, int] = (*GenericClientStream[int, int, int, int])(nil) +var _ ServerStreamingServer[int, int, int] = (*GenericServerStream[int, int, int, int])(nil) +var _ ClientStreamingServer[int, int, int, int] = (*GenericServerStream[int, int, int, int])(nil) +var _ BidiStreamingServer[int, int, int, int] = (*GenericServerStream[int, int, int, int])(nil) + +type StreamingMode = serviceinfo.StreamingMode + +/* Streaming Mode +--------------- [Unary Streaming] --------------- +--------------- (Req) returns (Res) --------------- +client.Send(req) === req ==> server.Recv(req) +client.Recv(res) <== res === server.Send(res) + + +------------------- [Client Streaming] ------------------- +--------------- (stream Req) returns (Res) --------------- +client.Send(req) === req ==> server.Recv(req) + ... +client.Send(req) === req ==> server.Recv(req) + +client.CloseSend() === EOF ==> server.Recv(EOF) +client.Recv(res) <== res === server.SendAndClose(res) +** OR +client.CloseAndRecv(res) === EOF ==> server.Recv(EOF) + <== res === server.SendAndClose(res) + + +------------------- [Server Streaming] ------------------- +---------- (Request) returns (stream Response) ---------- +client.Send(req) === req ==> server.Recv(req) +client.Recv(res) <== res === server.Send(req) + ... +client.Recv(res) <== res === server.Send(req) +client.Recv(EOF) <== EOF === server handler return + + +----------- [Bidirectional Streaming] ----------- +--- (stream Request) returns (stream Response) --- +* goroutine 1 * +client.Send(req) === req ==> server.Recv(req) + ... +client.Send(req) === req ==> server.Recv(req) +client.CloseSend() === EOF ==> server.Recv(EOF) + +* goroutine 2 * +client.Recv(res) <== res === server.Send(req) + ... +client.Recv(res) <== res === server.Send(req) +client.Recv(EOF) <== EOF === server handler return +*/ + +const ( + StreamingNone = serviceinfo.StreamingNone + StreamingUnary = serviceinfo.StreamingUnary + StreamingClient = serviceinfo.StreamingClient + StreamingServer = serviceinfo.StreamingServer + StreamingBidirectional = serviceinfo.StreamingBidirectional +) + +type Stream interface { + Mode() StreamingMode + Service() string + Method() string + SendMsg(ctx context.Context, m any) error + RecvMsg(ctx context.Context, m any) error +} + +type ClientStream interface { + Stream + CloseSend(ctx context.Context) error +} + +type ServerStream interface { + Stream +} + +// client 必须通过 metainfo.WithValue(ctx, ..) 给下游传递信息 +// client 必须通过 metainfo.GetValue(ctx, ..) 拿到当前 server 的透传信息 +// client 必须通过 Header() 拿到下游 server 的透传信息 +type ClientStreamMetadata[Header, Trailer any] interface { + Header() (Header, error) + Trailer() (Trailer, error) +} + +// server 可以通过 Set/SendXXX 给上游回传信息 +type ServerStreamMetadata[Header, Trailer any] interface { + SetHeader(hd Header) error + SendHeader(hd Header) error + SetTrailer(hd Trailer) error +} + +type ServerStreamingClient[Header, Trailer, Res any] interface { + Recv(ctx context.Context) (*Res, error) + ClientStream + ClientStreamMetadata[Header, Trailer] +} + +type ServerStreamingServer[Header, Trailer, Res any] interface { + Send(ctx context.Context, res *Res) error + ServerStream + ServerStreamMetadata[Header, Trailer] +} + +type ClientStreamingClient[Header, Trailer, Req, Res any] interface { + Send(ctx context.Context, req *Req) error + CloseAndRecv(ctx context.Context) (*Res, error) + ClientStream + ClientStreamMetadata[Header, Trailer] +} + +type ClientStreamingServer[Header, Trailer, Req, Res any] interface { + Recv(ctx context.Context) (*Req, error) + //SendAndClose(ctx context.Context, res *Res) error + ServerStream + ServerStreamMetadata[Header, Trailer] +} + +type BidiStreamingClient[Header, Trailer, Req, Res any] interface { + Send(ctx context.Context, req *Req) error + Recv(ctx context.Context) (*Res, error) + ClientStream + ClientStreamMetadata[Header, Trailer] +} + +type BidiStreamingServer[Header, Trailer, Req, Res any] interface { + Recv(ctx context.Context) (*Req, error) + Send(ctx context.Context, res *Res) error + ServerStream + ServerStreamMetadata[Header, Trailer] +} + +type GenericStreamIOMiddlewareSetter interface { + SetStreamSendEndpoint(e StreamSendEndpoint) + SetStreamRecvEndpoint(e StreamSendEndpoint) +} + +func NewGenericClientStream[Header, Trailer, Req, Res any](cs ClientStream) *GenericClientStream[Header, Trailer, Req, Res] { + return &GenericClientStream[Header, Trailer, Req, Res]{ + ClientStream: cs, + ClientStreamMetadata: cs.(ClientStreamMetadata[Header, Trailer]), + } +} + +type GenericClientStream[Header, Trailer, Req, Res any] struct { + ClientStream + ClientStreamMetadata[Header, Trailer] + StreamSendMiddleware + StreamRecvMiddleware +} + +func (x *GenericClientStream[Header, Trailer, Req, Res]) SetStreamSendMiddleware(e StreamSendMiddleware) { + x.StreamSendMiddleware = e +} + +func (x *GenericClientStream[Header, Trailer, Req, Res]) SetStreamRecvMiddleware(e StreamRecvMiddleware) { + x.StreamRecvMiddleware = e +} + +func (x *GenericClientStream[Header, Trailer, Req, Res]) SendMsg(ctx context.Context, m any) error { + if x.StreamSendMiddleware != nil { + return x.StreamSendMiddleware(streamSendNext)(ctx, x.ClientStream, m) + } + return x.ClientStream.SendMsg(ctx, m) +} + +func (x *GenericClientStream[Header, Trailer, Req, Res]) RecvMsg(ctx context.Context, m any) (err error) { + if x.StreamRecvMiddleware != nil { + err = x.StreamRecvMiddleware(streamRecvNext)(ctx, x.ClientStream, m) + } else { + err = x.ClientStream.RecvMsg(ctx, m) + } + return err +} + +func (x *GenericClientStream[Header, Trailer, Req, Res]) Send(ctx context.Context, m *Req) error { + return x.SendMsg(ctx, m) +} + +func (x *GenericClientStream[Header, Trailer, Req, Res]) Recv(ctx context.Context) (m *Res, err error) { + m = new(Res) + if err = x.RecvMsg(ctx, m); err != nil { + return nil, err + } + return m, nil +} + +func (x *GenericClientStream[Header, Trailer, Req, Res]) CloseAndRecv(ctx context.Context) (*Res, error) { + if err := x.ClientStream.CloseSend(ctx); err != nil { + return nil, err + } + return x.Recv(ctx) +} + +func NewGenericServerStream[Header, Trailer, Req, Res any](ss ServerStream) *GenericServerStream[Header, Trailer, Req, Res] { + return &GenericServerStream[Header, Trailer, Req, Res]{ + ServerStream: ss, + ServerStreamMetadata: ss.(ServerStreamMetadata[Header, Trailer]), + } +} + +type GenericServerStream[Header, Trailer, Req, Res any] struct { + ServerStream + ServerStreamMetadata[Header, Trailer] + StreamSendMiddleware + StreamRecvMiddleware +} + +func (x *GenericServerStream[Header, Trailer, Req, Res]) SetStreamSendMiddleware(e StreamSendMiddleware) { + x.StreamSendMiddleware = e +} + +func (x *GenericServerStream[Header, Trailer, Req, Res]) SetStreamRecvMiddleware(e StreamRecvMiddleware) { + x.StreamRecvMiddleware = e +} + +func (x *GenericServerStream[Header, Trailer, Req, Res]) SendMsg(ctx context.Context, m any) error { + if x.StreamSendMiddleware != nil { + return x.StreamSendMiddleware(streamSendNext)(ctx, x.ServerStream, m) + } + return x.ServerStream.SendMsg(ctx, m) +} + +func (x *GenericServerStream[Header, Trailer, Req, Res]) RecvMsg(ctx context.Context, m any) (err error) { + if x.StreamRecvMiddleware != nil { + err = x.StreamRecvMiddleware(streamRecvNext)(ctx, x.ServerStream, m) + } else { + err = x.ServerStream.RecvMsg(ctx, m) + } + return err +} + +func (x *GenericServerStream[Header, Trailer, Req, Res]) Send(ctx context.Context, m *Res) error { + if x.StreamSendMiddleware != nil { + return x.StreamSendMiddleware(streamSendNext)(ctx, x.ServerStream, m) + } + return x.ServerStream.SendMsg(ctx, m) +} + +func (x *GenericServerStream[Header, Trailer, Req, Res]) SendAndClose(ctx context.Context, m *Res) error { + return x.Send(ctx, m) +} + +func (x *GenericServerStream[Header, Trailer, Req, Res]) Recv(ctx context.Context) (m *Req, err error) { + m = new(Req) + if x.StreamRecvMiddleware != nil { + err = x.StreamRecvMiddleware(streamRecvNext)(ctx, x.ServerStream, m) + } else { + err = x.ServerStream.RecvMsg(ctx, m) + } + if err != nil { + return nil, err + } + return m, nil +} + +func streamSendNext(ctx context.Context, stream Stream, msg any) (err error) { + return stream.SendMsg(ctx, msg) +} + +func streamRecvNext(ctx context.Context, stream Stream, msg any) (err error) { + return stream.RecvMsg(ctx, msg) +} diff --git a/pkg/streamx/stream_args.go b/pkg/streamx/stream_args.go new file mode 100644 index 0000000000..1c0482ff30 --- /dev/null +++ b/pkg/streamx/stream_args.go @@ -0,0 +1,104 @@ +package streamx + +import ( + "context" + "errors" +) + +type StreamCtxKey struct{} + +func WithStreamArgsContext(ctx context.Context, args StreamArgs) context.Context { + ctx = context.WithValue(ctx, StreamCtxKey{}, args) + return ctx +} + +func GetStreamArgsFromContext(ctx context.Context) (args StreamArgs) { + val := ctx.Value(StreamCtxKey{}) + if val == nil { + return nil + } + args, _ = val.(StreamArgs) + return args +} + +type StreamArgs interface { + Stream() Stream +} + +func AsStream(args interface{}) (Stream, error) { + s, ok := args.(StreamArgs) + if !ok { + return nil, errors.New("asStream expects StreamArgs") + } + return s.Stream(), nil +} + +type MutableStreamArgs interface { + SetStream(st Stream) +} + +func AsMutableStreamArgs(args StreamArgs) MutableStreamArgs { + margs, ok := args.(MutableStreamArgs) + if !ok { + return nil + } + return margs +} + +type streamArgs struct { + stream Stream +} + +func (s *streamArgs) Stream() Stream { + return s.stream +} + +func (s *streamArgs) SetStream(st Stream) { + s.stream = st +} + +func NewStreamArgs(stream Stream) StreamArgs { + return &streamArgs{stream: stream} +} + +type StreamReqArgs interface { + Req() any + SetReq(req any) +} + +type StreamResArgs interface { + Res() any + SetRes(res any) +} + +func NewStreamReqArgs(req any) StreamReqArgs { + return &streamReqArgs{req: req} +} + +type streamReqArgs struct { + req any +} + +func (s *streamReqArgs) Req() any { + return s.req +} + +func (s *streamReqArgs) SetReq(req any) { + s.req = req +} + +func NewStreamResArgs(res any) StreamResArgs { + return &streamResArgs{res: res} +} + +type streamResArgs struct { + res any +} + +func (s *streamResArgs) Res() any { + return s.res +} + +func (s *streamResArgs) SetRes(res any) { + s.res = res +} diff --git a/pkg/streamx/stream_middleware.go b/pkg/streamx/stream_middleware.go new file mode 100644 index 0000000000..30ab1548b1 --- /dev/null +++ b/pkg/streamx/stream_middleware.go @@ -0,0 +1,48 @@ +package streamx + +import ( + "context" +) + +type StreamHandler struct { + Handler any + StreamMiddleware StreamMiddleware + StreamRecvMiddleware StreamRecvMiddleware + StreamSendMiddleware StreamSendMiddleware +} + +type StreamEndpoint func(ctx context.Context, streamArgs StreamArgs, reqArgs StreamReqArgs, resArgs StreamResArgs) (err error) +type StreamMiddleware func(next StreamEndpoint) StreamEndpoint + +type StreamRecvEndpoint func(ctx context.Context, stream Stream, res any) (err error) +type StreamSendEndpoint func(ctx context.Context, stream Stream, req any) (err error) + +type StreamRecvMiddleware func(next StreamRecvEndpoint) StreamRecvEndpoint +type StreamSendMiddleware func(next StreamSendEndpoint) StreamSendEndpoint + +func StreamMiddlewareChain(mws ...StreamMiddleware) StreamMiddleware { + return func(next StreamEndpoint) StreamEndpoint { + for i := len(mws) - 1; i >= 0; i-- { + next = mws[i](next) + } + return next + } +} + +func StreamRecvMiddlewareChain(mws ...StreamRecvMiddleware) StreamRecvMiddleware { + return func(next StreamRecvEndpoint) StreamRecvEndpoint { + for i := len(mws) - 1; i >= 0; i-- { + next = mws[i](next) + } + return next + } +} + +func StreamSendMiddlewareChain(mws ...StreamSendMiddleware) StreamSendMiddleware { + return func(next StreamSendEndpoint) StreamSendEndpoint { + for i := len(mws) - 1; i >= 0; i-- { + next = mws[i](next) + } + return next + } +} diff --git a/pkg/utils/contextmap/contextmap.go b/pkg/utils/contextmap/contextmap.go index d002a8bd9e..bf30fc1463 100644 --- a/pkg/utils/contextmap/contextmap.go +++ b/pkg/utils/contextmap/contextmap.go @@ -25,7 +25,7 @@ type contextMapKey struct{} // WithContextMap returns a new context that carries a sync.Map // It's useful if you want to share a sync.Map between middlewares, especially for -// Middleware and RecvMiddleware/SendMiddleware, since in recv/send middlewares, +// StreamMiddleware and RecvMiddleware/SendMiddleware, since in recv/send middlewares, // we can only get the stream.Context() which is a fixed node in the context tree. // // Note: it's not added to context by default, and you should add it yourself if needed. diff --git a/server/option_advanced_test.go b/server/option_advanced_test.go index 66ba46de80..4b99bbecd0 100644 --- a/server/option_advanced_test.go +++ b/server/option_advanced_test.go @@ -19,7 +19,6 @@ package server import ( "context" "errors" - "reflect" "testing" "time" @@ -232,12 +231,6 @@ func TestWithSupportedTransportsFunc(t *testing.T) { }, wantTransports: []string{"ttheader_mux"}, }, - { - options: []Option{ - WithTransHandlerFactory(nil), - }, - wantTransports: nil, - }, } var svr Server for _, tcase := range cases { @@ -246,7 +239,7 @@ func TestWithSupportedTransportsFunc(t *testing.T) { svr.RegisterService(svcInfo, new(mockImpl)) svr.(*server).fillMoreServiceInfo(nil) svcInfo = svr.(*server).svcs.SearchService(svcInfo.ServiceName, mocks.MockMethod, false) - test.Assert(t, reflect.DeepEqual(svcInfo.Extra["transports"], tcase.wantTransports)) + test.DeepEqual(t, svcInfo.Extra["transports"], tcase.wantTransports) } } diff --git a/server/option_test.go b/server/option_test.go index cc0e25eb24..eaf90e7b59 100644 --- a/server/option_test.go +++ b/server/option_test.go @@ -32,10 +32,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" "github.com/cloudwego/kitex/pkg/remote/codec/thrift" - "github.com/cloudwego/kitex/pkg/remote/trans/detection" - "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/remote/trans/netpollmux" - "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -257,8 +254,6 @@ func TestMuxTransportOption(t *testing.T) { goWaitAndStop(t, svr1) err = svr1.Run() test.Assert(t, err == nil, err) - iSvr1 := svr1.(*server) - test.DeepEqual(t, iSvr1.opt.RemoteOpt.SvrHandlerFactory, detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory())) svr2, _ := NewTestServer(WithMuxTransport()) err = svr2.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) diff --git a/server/server.go b/server/server.go index 56d454e67b..82116c397f 100644 --- a/server/server.go +++ b/server/server.go @@ -27,6 +27,11 @@ import ( "sync" "time" + "github.com/cloudwego/kitex/pkg/remote/trans/detection" + "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/localsession/backup" internal_server "github.com/cloudwego/kitex/internal/server" @@ -42,6 +47,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/bound" "github.com/cloudwego/kitex/pkg/remote/remotesvr" + streamxstrans "github.com/cloudwego/kitex/pkg/remote/trans/streamx" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" @@ -322,7 +328,7 @@ func (s *server) Stop() (err error) { func (s *server) buildServiceMiddleware() endpoint.Middleware { hasServiceMW := false for _, svc := range s.svcs.svcMap { - if svc.mw != nil { + if svc.MW != nil { hasServiceMW = true break } @@ -335,8 +341,8 @@ func (s *server) buildServiceMiddleware() endpoint.Middleware { ri := rpcinfo.GetRPCInfo(ctx) serviceName := ri.Invocation().ServiceName() svc := s.svcs.svcMap[serviceName] - if svc != nil && svc.mw != nil { - next = svc.mw(next) + if svc != nil && svc.MW != nil { + next = svc.MW(next) } return next(ctx, req, resp) } @@ -395,11 +401,22 @@ func (s *server) invokeHandleEndpoint() endpoint.Endpoint { // clear session backup.ClearCtx() }() - implHandlerFunc := svcInfo.MethodInfo(methodName).Handler() + minfo := svcInfo.MethodInfo(methodName) + implHandlerFunc := minfo.Handler() rpcinfo.Record(ctx, ri, stats.ServerHandleStart, nil) // set session backup.BackupCtx(ctx) - err = implHandlerFunc(ctx, svc.handler, args, resp) + + handler := svc.handler + if minfo.IsStreaming() { + handler = streamx.StreamHandler{ + Handler: svc.handler, + StreamMiddleware: svc.SMW, + StreamRecvMiddleware: svc.SRecvMW, + StreamSendMiddleware: svc.SSendMW, + } + } + err = implHandlerFunc(ctx, handler, args, resp) if err != nil { if bizErr, ok := kerrors.FromBizStatusError(err); ok { if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { @@ -538,6 +555,23 @@ func doAddBoundHandler(h remote.BoundHandler, opt *remote.ServerOption) { func (s *server) newSvrTransHandler() (handler remote.ServerTransHandler, err error) { transHdlrFactory := s.opt.RemoteOpt.SvrHandlerFactory + if transHdlrFactory == nil { + candidateFactories := make([]remote.ServerTransHandlerFactory, 0) + for _, svc := range s.svcs.svcMap { + if svc.streamingProvider != nil { + candidateFactories = append(candidateFactories, + streamxstrans.NewSvrTransHandlerFactory(svc.streamingProvider), + ) + } + } + candidateFactories = append(candidateFactories, + nphttp2.NewSvrTransHandlerFactory(), + ) + transHdlrFactory = detection.NewSvrTransHandlerFactory( + netpoll.NewSvrTransHandlerFactory(), + candidateFactories..., + ) + } transHdlr, err := transHdlrFactory.NewTransHandler(s.opt.RemoteOpt) if err != nil { return nil, err diff --git a/server/server_test.go b/server/server_test.go index 42a7734e65..14f2500501 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1288,8 +1288,8 @@ func newStreamingServer(svcName string, args []streamingMethodArg, mws []endpoin svcs: &services{ svcMap: map[string]*service{ svcName: { - svcInfo: svcInfo, - mw: endpoint.Chain(mws...), + svcInfo: svcInfo, + serviceMiddlewares: serviceMiddlewares{MW: endpoint.Chain(mws...)}, }, }, }, diff --git a/server/service.go b/server/service.go index 448a2d0230..6dbc302f62 100644 --- a/server/service.go +++ b/server/service.go @@ -21,18 +21,26 @@ import ( "fmt" "github.com/cloudwego/kitex/pkg/endpoint" - "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" ) +type serviceMiddlewares struct { + MW endpoint.Middleware + SMW streamx.StreamMiddleware + SRecvMW streamx.StreamRecvMiddleware + SSendMW streamx.StreamSendMiddleware +} + type service struct { - svcInfo *serviceinfo.ServiceInfo - handler interface{} - mw endpoint.Middleware + svcInfo *serviceinfo.ServiceInfo + handler interface{} + streamingProvider streamx.ServerProvider + serviceMiddlewares } -func newService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, mw endpoint.Middleware) *service { - return &service{svcInfo: svcInfo, handler: handler, mw: mw} +func newService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, provider streamx.ServerProvider, smw serviceMiddlewares) *service { + return &service{svcInfo: svcInfo, handler: handler, streamingProvider: provider, serviceMiddlewares: smw} } type services struct { @@ -51,11 +59,25 @@ func newServices() *services { } func (s *services) addService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, registerOpts *RegisterOptions) error { - var serviceMW endpoint.Middleware + // prepare service provider + serviceProvider := registerOpts.Provider + + // prepare serviceMiddlewares + var serviceMWs serviceMiddlewares if len(registerOpts.Middlewares) > 0 { - serviceMW = endpoint.Chain(registerOpts.Middlewares...) + serviceMWs.MW = endpoint.Chain(registerOpts.Middlewares...) } - svc := newService(svcInfo, handler, serviceMW) + if len(registerOpts.StreamMiddlewares) > 0 { + serviceMWs.SMW = streamx.StreamMiddlewareChain(registerOpts.StreamMiddlewares...) + } + if len(registerOpts.StreamRecvMiddlewares) > 0 { + serviceMWs.SRecvMW = streamx.StreamRecvMiddlewareChain(registerOpts.StreamRecvMiddlewares...) + } + if len(registerOpts.StreamSendMiddlewares) > 0 { + serviceMWs.SSendMW = streamx.StreamSendMiddlewareChain(registerOpts.StreamSendMiddlewares...) + } + + svc := newService(svcInfo, handler, serviceProvider, serviceMWs) if registerOpts.IsFallbackService { if s.fallbackSvc != nil { return fmt.Errorf("multiple fallback services cannot be registered. [%s] is already registered as a fallback service", s.fallbackSvc.svcInfo.ServiceName) diff --git a/server/stream.go b/server/stream.go index d171fe5883..1043c69ff6 100644 --- a/server/stream.go +++ b/server/stream.go @@ -24,6 +24,7 @@ import ( ) func (s *server) initStreamMiddlewares(ctx context.Context) { + // for old version streaming s.opt.Streaming.EventHandler = s.opt.TracerCtl.GetStreamEventHandler() s.opt.Streaming.InitMiddlewares(ctx) } diff --git a/server/streamxserver/server.go b/server/streamxserver/server.go new file mode 100644 index 0000000000..14b9fb68b5 --- /dev/null +++ b/server/streamxserver/server.go @@ -0,0 +1,16 @@ +package streamxserver + +import ( + "github.com/cloudwego/kitex/server" +) + +type Server = server.Server + +func NewServer(opts ...Option) server.Server { + iopts := make([]server.Option, 0, len(opts)+1) + for _, opt := range opts { + iopts = append(iopts, convertServerOption(opt)) + } + s := server.NewServer(iopts...) + return s +} diff --git a/server/streamxserver/server_gen.go b/server/streamxserver/server_gen.go new file mode 100644 index 0000000000..7d5ed074bf --- /dev/null +++ b/server/streamxserver/server_gen.go @@ -0,0 +1,100 @@ +package streamxserver + +import ( + "context" + "errors" + "reflect" + + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" +) + +func InvokeStream[Header, Trailer, Req, Res any]( + ctx context.Context, smode serviceinfo.StreamingMode, + handler any, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + // prepare args + sArgs := streamx.GetStreamArgsFromContext(ctx) + if sArgs == nil { + return errors.New("server stream is nil") + } + shandler := handler.(streamx.StreamHandler) + gs := streamx.NewGenericServerStream[Header, Trailer, Req, Res](sArgs.Stream()) + gs.SetStreamRecvMiddleware(shandler.StreamRecvMiddleware) + gs.SetStreamSendMiddleware(shandler.StreamSendMiddleware) + + // before handler + var req *Req + var res *Res + switch smode { + case serviceinfo.StreamingUnary, serviceinfo.StreamingServer: + req, err = gs.Recv(ctx) + if err != nil { + return err + } + reqArgs.SetReq(req) + default: + } + + // handler call + // TODO: cache handler + rhandler := reflect.ValueOf(shandler.Handler) + mhandler := rhandler.MethodByName(sArgs.Stream().Method()) + streamInvoke := func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + switch smode { + case serviceinfo.StreamingUnary: + called := mhandler.Call([]reflect.Value{ + reflect.ValueOf(ctx), + reflect.ValueOf(req), + }) + _res, _err := called[0].Interface(), called[1].Interface() + if _err != nil { + return _err.(error) + } + res = _res.(*Res) + if err = gs.SendAndClose(ctx, res); err != nil { + return err + } + resArgs.SetRes(res) + case serviceinfo.StreamingClient: + called := mhandler.Call([]reflect.Value{ + reflect.ValueOf(ctx), + reflect.ValueOf(gs), + }) + _res, _err := called[0].Interface(), called[1].Interface() + if _err != nil { + return _err.(error) + } + res = _res.(*Res) + if err = gs.Send(ctx, res); err != nil { + return err + } + resArgs.SetRes(res) + case serviceinfo.StreamingServer: + called := mhandler.Call([]reflect.Value{ + reflect.ValueOf(ctx), + reflect.ValueOf(req), + reflect.ValueOf(gs), + }) + _err := called[0].Interface() + if _err != nil { + return _err.(error) + } + case serviceinfo.StreamingBidirectional: + called := mhandler.Call([]reflect.Value{ + reflect.ValueOf(ctx), + reflect.ValueOf(gs), + }) + _err := called[0].Interface() + if _err != nil { + return _err.(error) + } + } + return nil + } + if shandler.StreamMiddleware != nil { + err = shandler.StreamMiddleware(streamInvoke)(ctx, sArgs, reqArgs, resArgs) + } else { + err = streamInvoke(ctx, sArgs, reqArgs, resArgs) + } + return err +} diff --git a/server/streamxserver/server_option.go b/server/streamxserver/server_option.go new file mode 100644 index 0000000000..ca4558d6a5 --- /dev/null +++ b/server/streamxserver/server_option.go @@ -0,0 +1,48 @@ +package streamxserver + +import ( + "net" + + internal_server "github.com/cloudwego/kitex/internal/server" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/server" +) + +type Option internal_server.Option +type Options = internal_server.Options + +func WithListener(ln net.Listener) Option { + return convertInternalServerOption(server.WithListener(ln)) +} + +func WithStreamMiddleware(mw streamx.StreamMiddleware) server.RegisterOption { + return server.RegisterOption{F: func(o *internal_server.RegisterOptions) { + o.StreamMiddlewares = append(o.StreamMiddlewares, mw) + }} +} + +func WithStreamRecvMiddleware(mw streamx.StreamRecvMiddleware) server.RegisterOption { + return server.RegisterOption{F: func(o *internal_server.RegisterOptions) { + o.StreamRecvMiddlewares = append(o.StreamRecvMiddlewares, mw) + }} +} + +func WithStreamSendMiddleware(mw streamx.StreamSendMiddleware) server.RegisterOption { + return server.RegisterOption{F: func(o *internal_server.RegisterOptions) { + o.StreamSendMiddlewares = append(o.StreamSendMiddlewares, mw) + }} +} + +func WithProvider(provider streamx.ServerProvider) server.RegisterOption { + return server.RegisterOption{F: func(o *internal_server.RegisterOptions) { + o.Provider = provider + }} +} + +func convertInternalServerOption(o internal_server.Option) Option { + return Option{F: o.F} +} + +func convertServerOption(o Option) internal_server.Option { + return internal_server.Option{F: o.F} +} diff --git a/transport/keys.go b/transport/keys.go index f63080eaed..93032fae9c 100644 --- a/transport/keys.go +++ b/transport/keys.go @@ -29,6 +29,7 @@ const ( HTTP GRPC HESSIAN2 + JSONRPC TTHeaderFramed = TTHeader | Framed ) From cec20d769d22c7ba501832d337b1e5dbfad70f77 Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Thu, 12 Sep 2024 15:24:42 +0800 Subject: [PATCH 02/34] feat: support generating ttheader stream v2 code (#1546) --- tool/cmd/kitex/args/args.go | 20 ++ tool/internal_pkg/generator/generator.go | 179 +++++++++++++++--- tool/internal_pkg/generator/generator_test.go | 2 +- tool/internal_pkg/generator/type.go | 10 + .../pluginmode/thriftgo/convertor.go | 10 + tool/internal_pkg/tpl/streamx/client.go | 91 +++++++++ .../tpl/streamx/handler.method.go | 21 ++ tool/internal_pkg/tpl/streamx/server.go | 45 +++++ tool/internal_pkg/tpl/streamx/service.go | 51 +++++ tool/internal_pkg/util/util.go | 7 + 10 files changed, 404 insertions(+), 32 deletions(-) create mode 100644 tool/internal_pkg/tpl/streamx/client.go create mode 100644 tool/internal_pkg/tpl/streamx/handler.method.go create mode 100644 tool/internal_pkg/tpl/streamx/server.go create mode 100644 tool/internal_pkg/tpl/streamx/service.go diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index db76a15a62..0198a29420 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -33,6 +33,7 @@ import ( "github.com/cloudwego/kitex/tool/internal_pkg/pluginmode/protoc" "github.com/cloudwego/kitex/tool/internal_pkg/pluginmode/thriftgo" "github.com/cloudwego/kitex/tool/internal_pkg/util" + "github.com/cloudwego/kitex/transport" ) // EnvPluginMode is an environment that kitex uses to distinguish run modes. @@ -131,6 +132,9 @@ func (a *Arguments) buildFlags(version string) *flag.FlagSet { f.BoolVar(&a.LocalThriftgo, "local_thriftgo", false, "Use local thriftgo exec instead of kitex embedded thriftgo.") f.Var(&a.BuiltinTpl, "tpl", "Specify kitex built-in template.") + f.BoolVar(&a.StreamX, "streamx", false, + "Generate streaming code with streamx interface", + ) f.BoolVar(&a.GenFrugal, "gen_frugal", false, `Gen frugal codec for those structs with (go.codec="frugal")`) f.Var(&a.FrugalStruct, "frugal_struct", "Gen frugal codec for given struct") @@ -190,6 +194,10 @@ func (a *Arguments) ParseArgs(version, curpath string, kitexArgs []string) (err if err != nil { return err } + err = a.checkStreamX() + if err != nil { + return err + } // todo finish protobuf if a.IDLType != "thrift" { a.GenPath = generator.KitexGenPath @@ -237,6 +245,18 @@ func (a *Arguments) checkServiceName() error { return nil } +func (a *Arguments) checkStreamX() error { + if !a.StreamX { + return nil + } + if a.IDLType == "thrift" { + // set TTHeader Streaming by default + a.Protocol = transport.TTHeader.String() + } + // todo: process pb and gRPC + return nil +} + func (a *Arguments) checkPath(curpath string) error { pathToGo, err := exec.LookPath("go") if err != nil { diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 88cee25ecc..6a82709819 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -26,6 +26,7 @@ import ( "github.com/cloudwego/kitex/tool/internal_pkg/log" "github.com/cloudwego/kitex/tool/internal_pkg/tpl" + "github.com/cloudwego/kitex/tool/internal_pkg/tpl/streamx" "github.com/cloudwego/kitex/tool/internal_pkg/util" "github.com/cloudwego/kitex/transport" ) @@ -50,6 +51,8 @@ const ( // built in tpls MultipleServicesTpl = "multiple_services" + + streamxTTHeaderRef = "ttstream" ) var ( @@ -150,6 +153,7 @@ type Config struct { FrugalStruct util.StringSlice BuiltinTpl util.StringSlice // specify the built-in template to use + StreamX bool } // Pack packs the Config into a slice of "key=val" strings. @@ -299,6 +303,9 @@ func (c *Config) ApplyExtension() error { } func (c *Config) IsUsingMultipleServicesTpl() bool { + if c.StreamX { + return true + } for _, part := range c.BuiltinTpl { if part == MultipleServicesTpl { return true @@ -429,10 +436,19 @@ func (g *generator) generateHandler(pkg *PackageInfo, svc *ServiceInfo, handlerF return f, nil } - task := Task{ - Name: HandlerFileName, - Path: handlerFilePath, - Text: tpl.HandlerTpl + "\n" + tpl.HandlerMethodsTpl, + var task Task + if g.StreamX && svc.HasStreaming { + task = Task{ + Name: HandlerFileName, + Path: handlerFilePath, + Text: tpl.HandlerTpl + "\n" + streamx.HandlerMethodsTpl, + } + } else { + task = Task{ + Name: HandlerFileName, + Path: handlerFilePath, + Text: tpl.HandlerTpl + "\n" + tpl.HandlerMethodsTpl, + } } g.setImports(task.Name, pkg) handle := func(task *Task, pkg *PackageInfo) (*File, error) { @@ -455,25 +471,29 @@ func (g *generator) GenerateService(pkg *PackageInfo) ([]*File, error) { ext = new(TemplateExtension) } - tasks := []*Task{ - { - Name: ClientFileName, - Path: util.JoinPath(output, ClientFileName), - Text: tpl.ClientTpl, - Ext: ext.ExtendClient, - }, - { - Name: ServerFileName, - Path: util.JoinPath(output, ServerFileName), - Text: tpl.ServerTpl, - Ext: ext.ExtendServer, - }, - { - Name: ServiceFileName, - Path: util.JoinPath(output, svcPkg+".go"), - Text: tpl.ServiceTpl, - }, + cliTask := &Task{ + Name: ClientFileName, + Path: util.JoinPath(output, ClientFileName), + Text: tpl.ClientTpl, + Ext: ext.ExtendClient, + } + svrTask := &Task{ + Name: ServerFileName, + Path: util.JoinPath(output, ServerFileName), + Text: tpl.ServerTpl, + Ext: ext.ExtendServer, + } + svcTask := &Task{ + Name: ServiceFileName, + Path: util.JoinPath(output, svcPkg+".go"), + Text: tpl.ServiceTpl, } + if g.StreamX && pkg.ServiceInfo.HasStreaming { + cliTask.Text = streamx.ClientTpl + svrTask.Text = streamx.ServerTpl + svcTask.Text = streamx.ServiceTpl + } + tasks := []*Task{cliTask, svrTask, svcTask} // do not generate invoker.go in service package by default if g.Config.GenerateInvoker { @@ -524,6 +544,9 @@ func (g *generator) updatePackageInfo(pkg *PackageInfo) { if strings.EqualFold(g.Protocol, transport.HESSIAN2.String()) { pkg.Protocol = transport.HESSIAN2 } + if strings.EqualFold(g.Protocol, transport.TTHeader.String()) { + pkg.Protocol = transport.TTHeader + } if pkg.Dependencies == nil { pkg.Dependencies = make(map[string]string) } @@ -539,19 +562,27 @@ func (g *generator) setImports(name string, pkg *PackageInfo) { pkg.Imports = make(map[string]map[string]bool) switch name { case ClientFileName: - pkg.AddImports("client") - if pkg.HasStreaming { - pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming") - pkg.AddImport("transport", "github.com/cloudwego/kitex/transport") - } - if len(pkg.AllMethods()) > 0 { - if needCallOpt(pkg) { - pkg.AddImports("callopt") + if g.StreamX && pkg.HasStreaming { + g.setStreamXClientImports(pkg) + } else { + pkg.AddImports("client") + if pkg.HasStreaming { + pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming") + pkg.AddImport("transport", "github.com/cloudwego/kitex/transport") + } + if len(pkg.AllMethods()) > 0 { + if needCallOpt(pkg) { + pkg.AddImports("callopt") + } + pkg.AddImports("context") } - pkg.AddImports("context") } fallthrough case HandlerFileName: + if g.StreamX && pkg.HasStreaming { + g.setStreamXHandlerImports(pkg) + return + } for _, m := range pkg.ServiceInfo.AllMethods() { if !m.ServerStreaming && !m.ClientStreaming { pkg.AddImports("context") @@ -568,11 +599,19 @@ func (g *generator) setImports(name string, pkg *PackageInfo) { } } case ServerFileName, InvokerFileName: + if g.StreamX && pkg.HasStreaming { + g.setStreamXServerImports(pkg) + return + } if len(pkg.CombineServices) == 0 { pkg.AddImport(pkg.ServiceInfo.PkgRefName, pkg.ServiceInfo.ImportPath) } pkg.AddImports("server") case ServiceFileName: + if g.StreamX && pkg.HasStreaming { + g.setStreamXServiceImports(pkg) + return + } pkg.AddImports("errors") pkg.AddImports("client") pkg.AddImport("kitex", "github.com/cloudwego/kitex/pkg/serviceinfo") @@ -646,3 +685,81 @@ func needCallOpt(pkg *PackageInfo) bool { } return needCallOpt } + +func (g *generator) setStreamXClientImports(pkg *PackageInfo) { + pkg.AddImports("client") + pkg.AddImports("github.com/cloudwego/kitex/client/streamxclient") + if len(pkg.AllMethods()) > 0 { + pkg.AddImports("context") + pkg.AddImports("github.com/cloudwego/kitex/client/streamxclient/streamxcallopt") + pkg.AddImports("github.com/cloudwego/kitex/pkg/serviceinfo") + } + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") + if g.IDLType == "thrift" { + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef) + } +} + +func (g *generator) setStreamXServerImports(pkg *PackageInfo) { + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") + pkg.AddImports("server") + pkg.AddImports("github.com/cloudwego/kitex/server/streamxserver") + if g.IDLType == "thrift" { + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef) + } + for _, m := range pkg.AllMethods() { + pkg.AddImports("context") + for _, a := range m.Args { + for _, dep := range a.Deps { + pkg.AddImport(dep.PkgRefName, dep.ImportPath) + } + } + if !m.Void && m.Resp != nil { + for _, dep := range m.Resp.Deps { + pkg.AddImport(dep.PkgRefName, dep.ImportPath) + } + } + } +} + +func (g *generator) setStreamXServiceImports(pkg *PackageInfo) { + pkg.AddImports("github.com/cloudwego/kitex/pkg/serviceinfo") + for _, m := range pkg.AllMethods() { + pkg.AddImports("context") + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") + pkg.AddImports("github.com/cloudwego/kitex/server/streamxserver") + if g.IDLType == "thrift" { + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef) + } + for _, a := range m.Args { + for _, dep := range a.Deps { + pkg.AddImport(dep.PkgRefName, dep.ImportPath) + } + } + if !m.Void && m.Resp != nil { + for _, dep := range m.Resp.Deps { + pkg.AddImport(dep.PkgRefName, dep.ImportPath) + } + } + } +} + +func (g *generator) setStreamXHandlerImports(pkg *PackageInfo) { + for _, m := range pkg.ServiceInfo.AllMethods() { + pkg.AddImports("context") + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") + if g.IDLType == "thrift" { + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef) + } + for _, a := range m.Args { + for _, dep := range a.Deps { + pkg.AddImport(dep.PkgRefName, dep.ImportPath) + } + } + if !m.Void && m.Resp != nil { + for _, dep := range m.Resp.Deps { + pkg.AddImport(dep.PkgRefName, dep.ImportPath) + } + } + } +} diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index 63c4e9d902..d34e423d67 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -69,7 +69,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false", "Rapid=false", "LocalThriftgo=false", "GenFrugal=false", "FrugalStruct=", "BuiltinTpl="}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false", "Rapid=false", "LocalThriftgo=false", "GenFrugal=false", "FrugalStruct=", "BuiltinTpl=", "StreamX=false"}, }, } for _, tt := range tests { diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index bdbb92ebf3..efdae50ef9 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -232,6 +232,7 @@ var funcs = map[string]interface{}{ "HasFeature": HasFeature, "FilterImports": FilterImports, "backquoted": BackQuoted, + "getStreamxRef": ToStreamxRef, } func AddTemplateFunc(key string, f interface{}) { @@ -395,3 +396,12 @@ func FilterImports(Imports map[string]map[string]bool, ms []*MethodInfo) map[str func BackQuoted(s string) string { return "`" + s + "`" } + +func ToStreamxRef(protocol transport.Protocol) string { + switch protocol { + case transport.TTHeader: + return streamxTTHeaderRef + default: + return "" + } +} diff --git a/tool/internal_pkg/pluginmode/thriftgo/convertor.go b/tool/internal_pkg/pluginmode/thriftgo/convertor.go index e7d6e9e3ea..ff5498afb3 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/convertor.go +++ b/tool/internal_pkg/pluginmode/thriftgo/convertor.go @@ -388,6 +388,9 @@ func (c *converter) convertTypes(req *plugin.Request) error { if c.IsHessian2() { si.Protocol = transport.HESSIAN2.String() } + if c.IsTTHeader() { + si.Protocol = transport.TTHeader.String() + } si.HandlerReturnKeepResp = c.Config.HandlerReturnKeepResp si.UseThriftReflection = c.Utils.Features().WithReflection @@ -435,6 +438,9 @@ func (c *converter) makeService(pkg generator.PkgInfo, svc *golang.Service) (*ge if c.IsHessian2() { si.Protocol = transport.HESSIAN2.String() } + if c.IsTTHeader() { + si.Protocol = transport.TTHeader.String() + } si.HandlerReturnKeepResp = c.Config.HandlerReturnKeepResp si.UseThriftReflection = c.Utils.Features().WithReflection return si, nil @@ -533,6 +539,10 @@ func (c *converter) IsHessian2() bool { return strings.EqualFold(c.Config.Protocol, transport.HESSIAN2.String()) } +func (c *converter) IsTTHeader() bool { + return strings.EqualFold(c.Config.Protocol, transport.TTHeader.String()) +} + func (c *converter) copyAnnotations(annotations parser.Annotations) parser.Annotations { copied := make(parser.Annotations, 0, len(annotations)) for _, annotation := range annotations { diff --git a/tool/internal_pkg/tpl/streamx/client.go b/tool/internal_pkg/tpl/streamx/client.go new file mode 100644 index 0000000000..cf95aded81 --- /dev/null +++ b/tool/internal_pkg/tpl/streamx/client.go @@ -0,0 +1,91 @@ +package streamx + +var ClientTpl = `// Code generated by Kitex {{.Version}}. DO NOT EDIT. + +package {{ToLower .ServiceName}} + +import ( + {{- range $path, $aliases := .Imports}} + {{- if not $aliases}} + "{{$path}}" + {{- else}} + {{- range $alias, $is := $aliases}} + {{$alias}} "{{$path}}" + {{- end}} + {{- end}} + {{- end}} +) +{{- $protocol := .Protocol | getStreamxRef}} + +type Client interface { +{{- range .AllMethods}} +{{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} +{{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} +{{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} +{{- $bidiSide := and .ClientStreaming .ServerStreaming}} +{{- $arg := index .Args 0}} + {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (r {{.Resp.Type}}, err error) + {{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.ClientStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) + {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (stream streamx.ServerStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr .Resp.Type}}], err error) + {{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.BidiStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) + {{- end}} +{{- end}} +} + +func NewClient(destService string, opts ...streamxclient.Option) (Client, error) { + var options []streamxclient.Option + options = append(options, streamxclient.WithDestService(destService)) + options = append(options, opts...) + cp, err := {{$protocol}}.NewClientProvider(svcInfo) + if err != nil { + return nil, err + } + options = append(options, streamxclient.WithProvider(cp)) + cli, err := streamxclient.NewClient(svcInfo, options...) + if err != nil { + return nil, err + } + kc := &kClient{streamer: cli, caller: cli.(client.Client)} + return kc, nil +} + +var _ Client = (*kClient)(nil) + +type kClient struct { + caller client.Client + streamer streamxclient.Client +} + +{{- range .AllMethods}} +{{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} +{{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} +{{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} +{{- $bidiSide := and .ClientStreaming .ServerStreaming}} +{{- $mode := ""}} + {{- if $bidiSide -}} {{- $mode = "serviceinfo.StreamingBidirectional" }} + {{- else if $serverSide -}} {{- $mode = "serviceinfo.StreamingServer" }} + {{- else if $clientSide -}} {{- $mode = "serviceinfo.StreamingClient" }} + {{- else if $unary -}} {{- $mode = "serviceinfo.StreamingUnary" }} + {{- end}} +{{- $arg := index .Args 0}} +func (c *kClient) {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) ({{.Resp.Type}}, error) { + res := new({{NotPtr .Resp.Type}}) + _, err := streamxclient.InvokeStream[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + ctx, c.streamer, {{$mode}}, "{{.RawName}}", req, res, callOptions...) + if err != nil { + return nil, err + } + return res, nil +{{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.ClientStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) { + return streamxclient.InvokeStream[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + ctx, c.streamer, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) +{{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (stream streamx.ServerStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr .Resp.Type}}], err error) { + return streamxclient.InvokeStream[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + ctx, c.streamer, {{$mode}}, "{{.RawName}}", req, nil, callOptions...) +{{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.BidiStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) { + return streamxclient.InvokeStream[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + ctx, c.streamer, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) +{{- end}} +} +{{- end}} +` diff --git a/tool/internal_pkg/tpl/streamx/handler.method.go b/tool/internal_pkg/tpl/streamx/handler.method.go new file mode 100644 index 0000000000..a46521717f --- /dev/null +++ b/tool/internal_pkg/tpl/streamx/handler.method.go @@ -0,0 +1,21 @@ +package streamx + +var HandlerMethodsTpl = `{{define "HandlerMethod"}} +{{- $protocol := .Protocol | getStreamxRef}} +{{- range .AllMethods}} +{{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} +{{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} +{{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} +{{- $bidiSide := and .ClientStreaming .ServerStreaming}} +{{- $arg := index .Args 0}} +func (s *{{.ServiceName}}Impl) {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}) (resp {{.Resp.Type}}, err error) { + {{- else if $clientSide}}(ctx context.Context, stream streamx.ClientStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) (resp {{.Resp.Type}}, err error) { + {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr .Resp.Type}}]) (err error) { + {{- else if $bidiSide}}(ctx context.Context, stream streamx.BidiStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) (err error) { + {{- end}} + // TODO: Your code here... + return +} +{{- end}} +{{end}}{{/* define "HandlerMethod" */}} +` diff --git a/tool/internal_pkg/tpl/streamx/server.go b/tool/internal_pkg/tpl/streamx/server.go new file mode 100644 index 0000000000..46daeb4667 --- /dev/null +++ b/tool/internal_pkg/tpl/streamx/server.go @@ -0,0 +1,45 @@ +package streamx + +var ServerTpl = `// Code generated by Kitex {{.Version}}. DO NOT EDIT. +package {{ToLower .ServiceName}} + +import ( + {{- range $path, $aliases := .Imports}} + {{- if not $aliases}} + "{{$path}}" + {{- else}} + {{- range $alias, $is := $aliases}} + {{$alias}} "{{$path}}" + {{- end}} + {{- end}} + {{- end}} +) +{{- $protocol := .Protocol | getStreamxRef}} + +type Server interface { +{{- range .AllMethods}} +{{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} +{{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} +{{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} +{{- $bidiSide := and .ClientStreaming .ServerStreaming}} +{{- $arg := index .Args 0}} + {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}) ({{.Resp.Type}}, error) + {{- else if $clientSide}}(ctx context.Context, stream streamx.ClientStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) ({{.Resp.Type}}, error) + {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr .Resp.Type}}]) error + {{- else if $bidiSide}}(ctx context.Context, stream streamx.BidiStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) error + {{- end}} +{{- end}} +} + +func RegisterService(svr server.Server, handler Server, opts ...server.RegisterOption) error { + sp, err := {{$protocol}}.NewServerProvider(svcInfo) + if err != nil { + return err + } + nopts := []server.RegisterOption{ + streamxserver.WithProvider(sp), + } + nopts = append(nopts, opts...) + return svr.RegisterService(svcInfo, handler, nopts...) +} +` diff --git a/tool/internal_pkg/tpl/streamx/service.go b/tool/internal_pkg/tpl/streamx/service.go new file mode 100644 index 0000000000..3c791012a2 --- /dev/null +++ b/tool/internal_pkg/tpl/streamx/service.go @@ -0,0 +1,51 @@ +package streamx + +var ServiceTpl = `// Code generated by Kitex {{.Version}}. DO NOT EDIT. +package {{ToLower .ServiceName}} + +import ( + {{- range $path, $aliases := .Imports}} + {{- if not $aliases}} + "{{$path}}" + {{- else}} + {{- range $alias, $is := $aliases}} + {{$alias}} "{{$path}}" + {{- end}} + {{- end}} + {{- end}} +) +{{- $protocol := .Protocol | getStreamxRef}} + +var svcInfo = &serviceinfo.ServiceInfo{ + ServiceName: "{{.RawServiceName}}", + Methods: map[string]serviceinfo.MethodInfo{ + {{- range .AllMethods}} + {{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} + {{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} + {{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} + {{- $bidiSide := and .ClientStreaming .ServerStreaming}} + {{- $arg := index .Args 0}} + {{- $mode := ""}} + {{- if $bidiSide -}} {{- $mode = "serviceinfo.StreamingBidirectional" }} + {{- else if $serverSide -}} {{- $mode = "serviceinfo.StreamingServer" }} + {{- else if $clientSide -}} {{- $mode = "serviceinfo.StreamingClient" }} + {{- else if $unary -}} {{- $mode = "serviceinfo.StreamingUnary" }} + {{- end}} + "{{.RawName}}": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + ctx, {{$mode}}, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode({{$mode}}), + ), + {{- end}} + }, + Extra: map[string]interface{}{ + "streaming": true, + }, +} + +` diff --git a/tool/internal_pkg/util/util.go b/tool/internal_pkg/util/util.go index 196f5b29bc..b2fb8b5eb3 100644 --- a/tool/internal_pkg/util/util.go +++ b/tool/internal_pkg/util/util.go @@ -336,3 +336,10 @@ func PrintlImports(imports []Import) string { } return builder.String() } + +func ToString(val interface{}) string { + if stringer, ok := val.(fmt.Stringer); ok { + return stringer.String() + } + return "" +} From cb056c0e4874ea2cfd1378e07fc28015606738e7 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Thu, 12 Sep 2024 15:28:16 +0800 Subject: [PATCH 03/34] feat: support mux transport feat: support mux transport --- .../provider/jsonrpc/jsonrpc_gen_test.go | 2 +- .../provider/ttstream/client_provier.go | 6 +- .../provider/ttstream/client_trans_pool.go | 103 +----------------- .../ttstream/client_trans_pool_longconn.go | 102 +++++++++++++++++ .../ttstream/client_trans_pool_mux.go | 50 +++++++++ pkg/streamx/provider/ttstream/scavenger.go | 42 +++++++ .../provider/ttstream/server_provider.go | 4 +- pkg/streamx/provider/ttstream/transport.go | 26 ++++- .../provider/ttstream/transport_test.go | 6 +- .../ttstream/ttstream_gen_service_test.go | 2 +- server/server.go | 2 +- 11 files changed, 232 insertions(+), 113 deletions(-) create mode 100644 pkg/streamx/provider/ttstream/client_trans_pool_longconn.go create mode 100644 pkg/streamx/provider/ttstream/client_trans_pool_mux.go create mode 100644 pkg/streamx/provider/ttstream/scavenger.go diff --git a/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go b/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go index bf0bf56394..e872ecfc6f 100644 --- a/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go +++ b/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go @@ -64,7 +64,7 @@ var serviceInfo = &serviceinfo.ServiceInfo{ serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), ), }, - Extra: map[string]interface{}{"streaming": true}, + Extra: map[string]interface{}{"streaming": true, "streamx": true}, } func NewClient(destService string, opts ...streamxclient.Option) (ClientInterface, error) { diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go index cf830b1248..de3a51a188 100644 --- a/pkg/streamx/provider/ttstream/client_provier.go +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -22,12 +22,12 @@ func NewClientProvider(sinfo *serviceinfo.ServiceInfo, opts ...ClientProviderOpt for _, opt := range opts { opt(cp) } - cp.transPool = newTransPool(sinfo) + cp.transPool = newMuxTransPool(sinfo) return cp, nil } type clientProvider struct { - transPool *transPool + transPool transPool sinfo *serviceinfo.ServiceInfo metaHandler MetaFrameHandler } @@ -59,8 +59,6 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callO runtime.SetFinalizer(cs, func(cs *clientStream) { klog.Debugf("client stream[%v] closing", cs.sid) _ = cs.close() - // TODO: currently using one conn one stream at same time - //_ = trans.close() c.transPool.Put(trans) }) return cs, err diff --git a/pkg/streamx/provider/ttstream/client_trans_pool.go b/pkg/streamx/provider/ttstream/client_trans_pool.go index f8103e73a5..57eb0bbbd4 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool.go @@ -1,105 +1,10 @@ package ttstream -import ( - "runtime" - "sync" - "time" - - "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/netpoll" -) - // TODO: it's to complex for users implement idle check // so let's implement it in netpoll -func newTransStack() *transStack { - return &transStack{} -} - -// FILO -type transStack struct { - mu sync.Mutex - stack []*transport // TODO: now it's a mem leak stack implementation - modified time.Time -} - -func (s *transStack) Pop() (trans *transport) { - s.mu.Lock() - if len(s.stack) == 0 { - s.mu.Unlock() - return nil - } - trans = s.stack[len(s.stack)-1] - s.stack = s.stack[:len(s.stack)-1] - s.mu.Unlock() - return trans -} - -func (s *transStack) Push(trans *transport) { - s.mu.Lock() - s.stack = append(s.stack, trans) - s.modified = time.Now() - s.mu.Unlock() -} - -func (s *transStack) Clear() { - s.mu.Lock() - s.stack = []*transport{} - s.modified = time.Now() - s.mu.Unlock() -} - -func newTransPool(sinfo *serviceinfo.ServiceInfo) *transPool { - tp := &transPool{sinfo: sinfo} - go func() { - now := time.Now() - deleteKeys := make([]string, 0) - tp.pool.Range(func(addr, value any) bool { - tstack := value.(*transStack) - duration := now.Sub(tstack.modified) - if duration >= time.Minute*10 { - deleteKeys = append(deleteKeys, addr.(string)) - } - return true - }) - }() - return tp -} - -type transPool struct { - pool sync.Map // {"addr":*transStack} - sinfo *serviceinfo.ServiceInfo -} - -func (c *transPool) Get(network string, addr string) (trans *transport, err error) { - var cstack *transStack - val, ok := c.pool.Load(addr) - if !ok { - // TODO: here may have a race problem - cstack = newTransStack() - _, _ = c.pool.LoadOrStore(addr, cstack) - } else { - cstack = val.(*transStack) - } - trans = cstack.Pop() - if trans != nil { - return trans, nil - } - conn, err := netpoll.DialConnection(network, addr, time.Second) - if err != nil { - return nil, err - } - trans = newTransport(clientTransport, c.sinfo, conn) - runtime.SetFinalizer(trans, func(t *transport) { t.close() }) - return trans, nil -} - -func (c *transPool) Put(trans *transport) { - var cstack *transStack - val, ok := c.pool.Load(trans.conn.RemoteAddr()) - if !ok { - return - } - cstack = val.(*transStack) - cstack.Push(trans) +// TODO: make transPoll configurable +type transPool interface { + Get(network string, addr string) (trans *transport, err error) + Put(trans *transport) } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go new file mode 100644 index 0000000000..8a54f332c7 --- /dev/null +++ b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go @@ -0,0 +1,102 @@ +package ttstream + +import ( + "runtime" + "sync" + "time" + + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/netpoll" +) + +func newTransPool(sinfo *serviceinfo.ServiceInfo) *longConnTransPool { + tp := &longConnTransPool{sinfo: sinfo} + go func() { + now := time.Now() + deleteKeys := make([]string, 0) + tp.pool.Range(func(addr, value any) bool { + tstack := value.(*transStack) + duration := now.Sub(tstack.modified) + if duration >= time.Minute*10 { + deleteKeys = append(deleteKeys, addr.(string)) + } + return true + }) + }() + return tp +} + +type longConnTransPool struct { + pool sync.Map // {"addr":*transStack} + sinfo *serviceinfo.ServiceInfo +} + +func (c *longConnTransPool) Get(network string, addr string) (trans *transport, err error) { + var cstack *transStack + val, ok := c.pool.Load(addr) + if !ok { + // TODO: here may have a race problem + cstack = newTransStack() + _, _ = c.pool.LoadOrStore(addr, cstack) + } else { + cstack = val.(*transStack) + } + trans = cstack.Pop() + if trans != nil { + return trans, nil + } + conn, err := netpoll.DialConnection(network, addr, time.Second) + if err != nil { + return nil, err + } + trans = newTransport(clientTransport, c.sinfo, conn) + runtime.SetFinalizer(trans, func(t *transport) { t.Close() }) + return trans, nil +} + +func (c *longConnTransPool) Put(trans *transport) { + var cstack *transStack + val, ok := c.pool.Load(trans.conn.RemoteAddr()) + if !ok { + return + } + cstack = val.(*transStack) + cstack.Push(trans) +} + +func newTransStack() *transStack { + return &transStack{} +} + +// FILO +type transStack struct { + mu sync.Mutex + stack []*transport // TODO: now it's a mem leak stack implementation + modified time.Time +} + +func (s *transStack) Pop() (trans *transport) { + s.mu.Lock() + if len(s.stack) == 0 { + s.mu.Unlock() + return nil + } + trans = s.stack[len(s.stack)-1] + s.stack = s.stack[:len(s.stack)-1] + s.mu.Unlock() + return trans +} + +func (s *transStack) Push(trans *transport) { + s.mu.Lock() + s.stack = append(s.stack, trans) + s.modified = time.Now() + s.mu.Unlock() +} + +func (s *transStack) Clear() { + s.mu.Lock() + s.stack = []*transport{} + s.modified = time.Now() + s.mu.Unlock() +} diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go new file mode 100644 index 0000000000..b08677a274 --- /dev/null +++ b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go @@ -0,0 +1,50 @@ +package ttstream + +import ( + "sync" + "time" + + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/netpoll" + "golang.org/x/sync/singleflight" +) + +var _ transPool = (*muxTransPool)(nil) + +func newMuxTransPool(sinfo *serviceinfo.ServiceInfo) transPool { + t := new(muxTransPool) + t.sinfo = sinfo + t.scavenger = newScavenger() + return t +} + +type muxTransPool struct { + sinfo *serviceinfo.ServiceInfo + pool sync.Map // addr:netpoll.Connection + scavenger *scavenger + sflight singleflight.Group +} + +func (m *muxTransPool) Get(network string, addr string) (trans *transport, err error) { + v, ok := m.pool.Load(addr) + if ok { + return v.(*transport), nil + } + v, err, _ = m.sflight.Do(addr, func() (interface{}, error) { + conn, err := netpoll.DialConnection(network, addr, time.Second) + if err != nil { + return nil, err + } + trans = newTransport(clientTransport, m.sinfo, conn) + m.scavenger.Add(trans) + return trans, nil + }) + if err != nil { + return nil, err + } + return v.(*transport), nil +} + +func (m *muxTransPool) Put(trans *transport) { + // do nothing +} diff --git a/pkg/streamx/provider/ttstream/scavenger.go b/pkg/streamx/provider/ttstream/scavenger.go new file mode 100644 index 0000000000..a809627448 --- /dev/null +++ b/pkg/streamx/provider/ttstream/scavenger.go @@ -0,0 +1,42 @@ +package ttstream + +import ( + "sync" + "time" +) + +type Object interface { + Available() bool + Close() error +} + +func newScavenger() *scavenger { + s := new(scavenger) + go s.Cleaning() + return s +} + +type scavenger struct { + sync.RWMutex + objects []Object +} + +func (s *scavenger) Add(o Object) { + s.Lock() + s.objects = append(s.objects, o) + s.Unlock() +} + +func (s *scavenger) Cleaning() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for range ticker.C { + s.RLock() + for _, o := range s.objects { + if !o.Available() { + _ = o.Close() + } + } + s.RUnlock() + } +} diff --git a/pkg/streamx/provider/ttstream/server_provider.go b/pkg/streamx/provider/ttstream/server_provider.go index 7a8c490bf0..57740e7e0b 100644 --- a/pkg/streamx/provider/ttstream/server_provider.go +++ b/pkg/streamx/provider/ttstream/server_provider.go @@ -47,8 +47,8 @@ func (s serverProvider) OnInactive(ctx context.Context, conn net.Conn) (context. if trans == nil { return ctx, nil } - // server should close transport - err := trans.close() + // server should Close transport + err := trans.Close() if err != nil { return nil, err } diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index 650e8d24dc..f35a1c08f2 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -20,6 +20,8 @@ const ( serverTransport int32 = 2 ) +var _ Object = (*transport)(nil) + type transport struct { kind int32 sinfo *serviceinfo.ServiceInfo @@ -30,6 +32,9 @@ type transport struct { sch chan *stream // in-coming stream channel wch chan Frame // out-coming frame channel stop chan struct{} + + // for scavenger check + lastActive atomic.Value // time.Time } func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Connection) *transport { @@ -78,6 +83,9 @@ func (t *transport) loadStreamIO(sid int32) (sio *streamIO, ok bool) { func (t *transport) loopRead() error { for { + now := time.Now() + t.lastActive.Store(now) + // decode frame fr, err := DecodeFrame(context.Background(), t.reader) if err != nil { @@ -121,7 +129,7 @@ func (t *transport) loopRead() error { } sio.input(fr) case trailerFrameType: - // Trailer Frame: recv trailer, close read direction + // Trailer Frame: recv trailer, Close read direction sio, ok := t.loadStreamIO(fr.sid) if !ok { return fmt.Errorf("transport[%d] read a unknown stream trailer: sid=%d", t.kind, fr.sid) @@ -140,6 +148,9 @@ func (t *transport) writeFrame(frame Frame) error { func (t *transport) loopWrite() error { for { + now := time.Now() + t.lastActive.Store(now) + select { case <-t.stop: // re-check wch queue @@ -159,7 +170,17 @@ func (t *transport) loopWrite() error { } } -func (t *transport) close() (err error) { +func (t *transport) Available() bool { + v := t.lastActive.Load() + if v == nil { + return true + } + lastActive := v.(time.Time) + // let unavailable time configurable + return time.Now().Sub(lastActive) < time.Minute*10 +} + +func (t *transport) Close() (err error) { select { case <-t.stop: default: @@ -225,6 +246,7 @@ var clientStreamID int32 // newStream create new stream on current connection // it's typically used by client side +// newStream is concurrency safe func (t *transport) newStream( ctx context.Context, method string, header map[string]string) (*stream, error) { if t.kind != clientTransport { diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go index 45c6848a45..f03b4b3741 100644 --- a/pkg/streamx/provider/ttstream/transport_test.go +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -159,7 +159,7 @@ func TestTransport(t *testing.T) { // send trailer(trailer is stored in ctx) err = cs.CloseSend(ctx) test.Assert(t, err == nil, err) - t.Logf("client stream[%d] close send", sid) + t.Logf("client stream[%d] Close send", sid) // recv trailer tl, err := cs.Trailer() @@ -174,8 +174,8 @@ func TestTransport(t *testing.T) { time.Sleep(time.Millisecond * 10) } - // close conn - err = trans.close() + // Close conn + err = trans.Close() test.Assert(t, err == nil, err) err = ln.Close() test.Assert(t, err == nil, err) diff --git a/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go b/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go index c47a464620..1e777b4988 100644 --- a/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go @@ -94,7 +94,7 @@ var streamingServiceInfo = &serviceinfo.ServiceInfo{ serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), ), }, - Extra: map[string]interface{}{"streaming": true}, + Extra: map[string]interface{}{"streaming": true, "streamx": true}, } // --- Define RegisterService interface --- diff --git a/server/server.go b/server/server.go index 82116c397f..e5cd271cb4 100644 --- a/server/server.go +++ b/server/server.go @@ -408,7 +408,7 @@ func (s *server) invokeHandleEndpoint() endpoint.Endpoint { backup.BackupCtx(ctx) handler := svc.handler - if minfo.IsStreaming() { + if minfo.IsStreaming() && svcInfo.Extra["streamx"] != nil { handler = streamx.StreamHandler{ Handler: svc.handler, StreamMiddleware: svc.SMW, From 6712719c399b501112cbb1fbf85d4133aaadabc7 Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Thu, 12 Sep 2024 16:19:39 +0800 Subject: [PATCH 04/34] fix: lots of bugs fix: lots of bugs fix: header flag check on bit fix: long conn reuse conn only when two trailer send fix: set read timeout to 10mins fix: trigger stream.Header() if trailer coming refactor: stack impl fix: transport should close all streams when it close fix: stack pop chore: add object pool clean log feat: close header and trailer trigger fix: new stream should not write on recv header fix: conn leak fix: close all streams when transport close fix: close stream when recv return error feat: support recv timeout refactor: object pool fix: stack range delete refactor: long conn pool fix: close transport if there is io error fix: close conn after all goroutines return fix: long conn close conn when wchan is not empty fix: long conn close twice fix: client should send trailer feat: header frame handler fix: long conn feat: support short conn chore: disable discard conn refactor: transport scavenger in long conn mode feat: support MaxIdlePerAddress fix: server close transport when peer closed fix: writeFrame should be closed safely fix: encode release buffer if no payload feat: long conn feat: enable convert options perf: turning flush delay perf: using loop write perf: optimise write direct perf: optimise write direct fix: pipe close trigger perf: optimise pipe and queue locker perf: use stand codec utils fix: server stream should not delete it from stream map fix: mux should store trans into pool perf: rm write loop refactor: using same header and trailer (#1554) perf: using netpoll shard queue feat: ctx canceling transmit (#1552) perf: reuse frame and linknode perf: use batch read queue to replace old pipe impl perf: reuse decode payload mem feat: use replace channel to pipeline --- client/client.go | 27 +- client/client_streamx.go | 28 +- client/stream.go | 29 +- client/streamxclient/client.go | 2 +- client/streamxclient/client_gen.go | 7 +- client/streamxclient/client_option.go | 23 +- .../streamxcallopt/call_option.go | 5 +- go.sum | 7 +- internal/client/option.go | 13 +- internal/test/port.go | 2 +- pkg/remote/remotecli/stream.go | 19 +- pkg/remote/trans/streamx/server_handler.go | 20 +- pkg/rpcinfo/interface.go | 5 + pkg/rpcinfo/mutable.go | 2 + pkg/rpcinfo/rpcconfig.go | 11 + pkg/streamx/client_options.go | 10 + pkg/streamx/client_provider.go | 17 + pkg/streamx/client_provider_internal.go | 16 + .../jsonrpc/metadata.go => header_trailer.go} | 2 +- pkg/streamx/provider/jsonrpc/client_option.go | 16 + .../provider/jsonrpc/client_provier.go | 19 +- .../provider/jsonrpc/jsonrpc_gen_test.go | 44 ++- .../provider/jsonrpc/jsonrpc_impl_test.go | 16 + pkg/streamx/provider/jsonrpc/jsonrpc_test.go | 16 + pkg/streamx/provider/jsonrpc/protocol.go | 16 + pkg/streamx/provider/jsonrpc/server_option.go | 16 + .../provider/jsonrpc/server_provider.go | 16 + pkg/streamx/provider/jsonrpc/stream.go | 38 +- pkg/streamx/provider/jsonrpc/transport.go | 22 +- .../provider/jsonrpc/transport_test.go | 16 + .../provider/ttstream/client_option.go | 40 ++ .../provider/ttstream/client_provier.go | 91 ++++- .../provider/ttstream/client_trans_pool.go | 22 +- .../ttstream/client_trans_pool_longconn.go | 127 +++--- .../ttstream/client_trans_pool_mux.go | 42 +- .../ttstream/client_trans_pool_shortconn.go | 45 +++ .../provider/ttstream/container/linklist.go | 41 ++ .../ttstream/container/object_pool.go | 110 ++++++ .../provider/ttstream/container/pipe.go | 122 ++++++ .../provider/ttstream/container/pipe_test.go | 72 ++++ .../provider/ttstream/container/queue.go | 117 ++++++ .../provider/ttstream/container/queue_test.go | 45 +++ .../provider/ttstream/container/stack.go | 170 +++++++++ .../provider/ttstream/container/stack_test.go | 99 +++++ pkg/streamx/provider/ttstream/frame.go | 126 ++++-- .../provider/ttstream/frame_handler.go | 27 ++ pkg/streamx/provider/ttstream/frame_test.go | 46 ++- pkg/streamx/provider/ttstream/ktx/ktx.go | 71 ++++ pkg/streamx/provider/ttstream/ktx/ktx_test.go | 41 ++ .../provider/ttstream/meta_frame_handler.go | 24 +- pkg/streamx/provider/ttstream/metadata.go | 34 +- pkg/streamx/provider/ttstream/mock_test.go | 17 + pkg/streamx/provider/ttstream/scavenger.go | 42 -- .../provider/ttstream/server_option.go | 16 + .../provider/ttstream/server_provider.go | 54 ++- pkg/streamx/provider/ttstream/stream.go | 179 ++++++--- .../ttstream/stream_header_trailer.go | 59 ++- pkg/streamx/provider/ttstream/stream_io.go | 103 +++-- pkg/streamx/provider/ttstream/transport.go | 360 ++++++++++++------ .../provider/ttstream/transport_buffer.go | 132 +++++++ .../provider/ttstream/transport_test.go | 28 +- .../provider/ttstream/ttstream_client_test.go | 358 ++++++++++++++--- .../provider/ttstream/ttstream_common_test.go | 17 + .../ttstream/ttstream_gen_codec_test.go | 16 +- .../ttstream/ttstream_gen_service_test.go | 48 ++- .../provider/ttstream/ttstream_server_test.go | 87 ++++- pkg/streamx/server_provider.go | 16 + pkg/streamx/server_provider_internal.go | 16 + pkg/streamx/stream.go | 104 ++--- pkg/streamx/stream_args.go | 16 + pkg/streamx/stream_middleware.go | 16 + server/streamxserver/server.go | 2 +- server/streamxserver/server_gen.go | 4 +- server/streamxserver/server_option.go | 6 +- tool/internal_pkg/tpl/streamx/service.go | 1 + 75 files changed, 3001 insertions(+), 688 deletions(-) create mode 100644 pkg/streamx/client_options.go rename pkg/streamx/{provider/jsonrpc/metadata.go => header_trailer.go} (79%) create mode 100644 pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go create mode 100644 pkg/streamx/provider/ttstream/container/linklist.go create mode 100644 pkg/streamx/provider/ttstream/container/object_pool.go create mode 100644 pkg/streamx/provider/ttstream/container/pipe.go create mode 100644 pkg/streamx/provider/ttstream/container/pipe_test.go create mode 100644 pkg/streamx/provider/ttstream/container/queue.go create mode 100644 pkg/streamx/provider/ttstream/container/queue_test.go create mode 100644 pkg/streamx/provider/ttstream/container/stack.go create mode 100644 pkg/streamx/provider/ttstream/container/stack_test.go create mode 100644 pkg/streamx/provider/ttstream/frame_handler.go create mode 100644 pkg/streamx/provider/ttstream/ktx/ktx.go create mode 100644 pkg/streamx/provider/ttstream/ktx/ktx_test.go delete mode 100644 pkg/streamx/provider/ttstream/scavenger.go create mode 100644 pkg/streamx/provider/ttstream/transport_buffer.go diff --git a/client/client.go b/client/client.go index da41570c27..03cc7cad49 100644 --- a/client/client.go +++ b/client/client.go @@ -20,12 +20,13 @@ import ( "context" "errors" "fmt" - "github.com/cloudwego/kitex/pkg/streamx" "runtime" "runtime/debug" "strconv" "sync/atomic" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/bytedance/gopkg/cloud/metainfo" "github.com/cloudwego/localsession/backup" @@ -437,6 +438,13 @@ func (kc *kClient) richRemoteOption() { // (newClientStreamer: call WriteMeta before remotecli.NewClient) transInfoHdlr := bound.NewTransMetaHandler(kc.opt.MetaHandlers) kc.opt.RemoteOpt.PrependBoundHandler(transInfoHdlr) + + // add meta handlers into streaming meta handlers + for _, h := range kc.opt.MetaHandlers { + if shdlr, ok := h.(remote.StreamingMetaHandler); ok { + kc.opt.RemoteOpt.StreamingMetaHandlers = append(kc.opt.RemoteOpt.StreamingMetaHandlers, shdlr) + } + } } } @@ -454,17 +462,6 @@ func (kc *kClient) buildInvokeChain() error { return err } kc.sEps = mwchain(innerStreamingEp) - - // streamx NewStream - innerStreamXEp, err := kc.invokeStreamXEndpoint() - if err != nil { - return err - } - kc.sxEps = mwchain(innerStreamXEp) - // streamx stream call - kc.sxStreamMW = streamx.StreamMiddlewareChain(kc.opt.SMWs...) - kc.sxStreamRecvMW = streamx.StreamRecvMiddlewareChain(kc.opt.SRecvMWs...) - kc.sxStreamSendMW = streamx.StreamSendMiddlewareChain(kc.opt.SSendMWs...) return nil } @@ -758,6 +755,12 @@ func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInf } } + // streamx config + sopt := opt.StreamXOptions + if sopt.RecvTimeout > 0 { + cfg.SetStreamRecvTimeout(sopt.RecvTimeout) + } + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) if callOpts != nil && callOpts.CompressorName != "" { diff --git a/client/client_streamx.go b/client/client_streamx.go index 0a84ca3fe0..64e2d8a672 100644 --- a/client/client_streamx.go +++ b/client/client_streamx.go @@ -2,8 +2,8 @@ package client import ( "context" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" - "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/streamx" ) @@ -17,30 +17,6 @@ func (kc *kClient) Middlewares() (streamMW streamx.StreamMiddleware, recvMW stre return kc.sxStreamMW, kc.sxStreamRecvMW, kc.sxStreamSendMW } -// return a bottom next function -// bottom next function will create a stream and change the streamx.Args -func (kc *kClient) invokeStreamXEndpoint() (endpoint.Endpoint, error) { - // TODO: implement trans handler layer and use trans factory - //transPipl, err := newCliTransHandler(kc.opt.RemoteOpt) - //if err != nil { - // return nil, err - //} - clientProvider, _ := kc.opt.RemoteOpt.Provider.(streamx.ClientProvider) - clientProvider = streamx.NewClientProvider(clientProvider) // wrap client provider - - return func(ctx context.Context, req, resp interface{}) (err error) { - ri := rpcinfo.GetRPCInfo(ctx) - cs, err := clientProvider.NewStream(ctx, ri) - if err != nil { - return err - } - streamArgs := resp.(streamx.StreamArgs) - // 此后的中间件才会有 Stream - streamx.AsMutableStreamArgs(streamArgs).SetStream(cs) - return nil - }, nil -} - // NewStream create stream for streamx mode func (kc *kClient) NewStream(ctx context.Context, method string, req any, callOptions ...streamxcallopt.CallOption) (streamx.ClientStream, error) { if !kc.inited { @@ -66,7 +42,7 @@ func (kc *kClient) NewStream(ctx context.Context, method string, req any, callOp // put streamArgs into response arg // it's an ugly trick but if we don't want to refactor too much, // this is the only way to compatible with current endpoint design - err = kc.sxEps(ctx, req, streamArgs) + err = kc.sEps(ctx, req, streamArgs) if err != nil { return nil, err } diff --git a/client/stream.go b/client/stream.go index a184ac7ecc..0e94ea13f8 100644 --- a/client/stream.go +++ b/client/stream.go @@ -24,9 +24,9 @@ import ( "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/endpoint" - "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/remotecli" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/streaming" @@ -75,15 +75,16 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) { if err != nil { return nil, err } - for _, h := range kc.opt.MetaHandlers { - if shdlr, ok := h.(remote.StreamingMetaHandler); ok { - kc.opt.RemoteOpt.StreamingMetaHandlers = append(kc.opt.RemoteOpt.StreamingMetaHandlers, shdlr) - } - } + // old version streaming mw recvEndpoint := kc.opt.Streaming.BuildRecvInvokeChain(kc.invokeRecvEndpoint()) sendEndpoint := kc.opt.Streaming.BuildSendInvokeChain(kc.invokeSendEndpoint()) + // streamx version streaming mw + kc.sxStreamMW = streamx.StreamMiddlewareChain(kc.opt.StreamXOptions.StreamMWs...) + kc.sxStreamRecvMW = streamx.StreamRecvMiddlewareChain(kc.opt.StreamXOptions.StreamRecvMWs...) + kc.sxStreamSendMW = streamx.StreamSendMiddlewareChain(kc.opt.StreamXOptions.StreamSendMWs...) + return func(ctx context.Context, req, resp interface{}) (err error) { // req and resp as &streaming.Stream ri := rpcinfo.GetRPCInfo(ctx) @@ -91,8 +92,20 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) { if err != nil { return } - clientStream := newStream(st, scm, kc, ri, kc.getStreamingMode(ri), sendEndpoint, recvEndpoint) - resp.(*streaming.Result).Stream = clientStream + + // streamx API + if cs, ok := st.(streamx.Stream); ok { + streamArgs := resp.(streamx.StreamArgs) + // 此后的中间件才会有 Stream + streamx.AsMutableStreamArgs(streamArgs).SetStream(cs) + return nil + } + + // old version streaming + if cs, ok := st.(streaming.Stream); ok { + clientStream := newStream(cs, scm, kc, ri, kc.getStreamingMode(ri), sendEndpoint, recvEndpoint) + resp.(*streaming.Result).Stream = clientStream + } return }, nil } diff --git a/client/streamxclient/client.go b/client/streamxclient/client.go index d477751d48..5bd52d8888 100644 --- a/client/streamxclient/client.go +++ b/client/streamxclient/client.go @@ -11,7 +11,7 @@ type Client = client.StreamX func NewClient(svcInfo *serviceinfo.ServiceInfo, opts ...Option) (Client, error) { iopts := make([]client.Option, 0, len(opts)+1) for _, opt := range opts { - iopts = append(iopts, convertClientOption(opt)) + iopts = append(iopts, ConvertStreamXClientOption(opt)) } nopts := iclient.NewOptions(iopts) c, err := client.NewClientWithOptions(svcInfo, nopts) diff --git a/client/streamxclient/client_gen.go b/client/streamxclient/client_gen.go index 13ed0f13f1..eb1e68f530 100644 --- a/client/streamxclient/client_gen.go +++ b/client/streamxclient/client_gen.go @@ -9,10 +9,10 @@ import ( "github.com/cloudwego/kitex/pkg/streamx" ) -func InvokeStream[Header, Trailer, Req, Res any]( +func InvokeStream[Req, Res any]( ctx context.Context, cli client.StreamX, smode serviceinfo.StreamingMode, method string, req *Req, res *Res, callOptions ...streamxcallopt.CallOption, -) (stream *streamx.GenericClientStream[Header, Trailer, Req, Res], err error) { +) (stream *streamx.GenericClientStream[Req, Res], err error) { reqArgs, resArgs := streamx.NewStreamReqArgs(nil), streamx.NewStreamResArgs(nil) streamArgs := streamx.NewStreamArgs(nil) // important notes: please don't set a typed nil value into interface arg like NewStreamReqArgs({typ: *Res, ptr: nil}) @@ -28,7 +28,7 @@ func InvokeStream[Header, Trailer, Req, Res any]( if err != nil { return nil, err } - stream = streamx.NewGenericClientStream[Header, Trailer, Req, Res](cs) + stream = streamx.NewGenericClientStream[Req, Res](cs) streamx.AsMutableStreamArgs(streamArgs).SetStream(stream) streamMW, recvMW, sendMW := cli.Middlewares() @@ -39,6 +39,7 @@ func InvokeStream[Header, Trailer, Req, Res any]( // assemble streaming args depend on each stream mode switch smode { case serviceinfo.StreamingUnary: + // client should call CloseSend and server should call SendAndClose if err = stream.SendMsg(ctx, req); err != nil { return err } diff --git a/client/streamxclient/client_option.go b/client/streamxclient/client_option.go index ed9103fff4..bcc101dc63 100644 --- a/client/streamxclient/client_option.go +++ b/client/streamxclient/client_option.go @@ -1,6 +1,8 @@ package streamxclient import ( + "time" + "github.com/cloudwego/kitex/client" internal_client "github.com/cloudwego/kitex/internal/client" "github.com/cloudwego/kitex/pkg/streamx" @@ -8,14 +10,19 @@ import ( ) type Option internal_client.Option -type Options = internal_client.Options func WithHostPorts(hostports ...string) Option { - return convertInternalClientOption(client.WithHostPorts(hostports...)) + return ConvertNativeClientOption(client.WithHostPorts(hostports...)) +} + +func WithRecvTimeout(timeout time.Duration) Option { + return Option{F: func(o *internal_client.Options, di *utils.Slice) { + o.StreamXOptions.RecvTimeout = timeout + }} } func WithDestService(destService string) Option { - return convertInternalClientOption(client.WithDestService(destService)) + return ConvertNativeClientOption(client.WithDestService(destService)) } func WithProvider(pvd streamx.ClientProvider) Option { @@ -26,26 +33,26 @@ func WithProvider(pvd streamx.ClientProvider) Option { func WithStreamMiddleware(smw streamx.StreamMiddleware) Option { return Option{F: func(o *internal_client.Options, di *utils.Slice) { - o.SMWs = append(o.SMWs, smw) + o.StreamXOptions.StreamMWs = append(o.StreamXOptions.StreamMWs, smw) }} } func WithStreamRecvMiddleware(smw streamx.StreamRecvMiddleware) Option { return Option{F: func(o *internal_client.Options, di *utils.Slice) { - o.SRecvMWs = append(o.SRecvMWs, smw) + o.StreamXOptions.StreamRecvMWs = append(o.StreamXOptions.StreamRecvMWs, smw) }} } func WithStreamSendMiddleware(smw streamx.StreamSendMiddleware) Option { return Option{F: func(o *internal_client.Options, di *utils.Slice) { - o.SSendMWs = append(o.SSendMWs, smw) + o.StreamXOptions.StreamSendMWs = append(o.StreamXOptions.StreamSendMWs, smw) }} } -func convertInternalClientOption(o internal_client.Option) Option { +func ConvertNativeClientOption(o internal_client.Option) Option { return Option{F: o.F} } -func convertClientOption(o Option) internal_client.Option { +func ConvertStreamXClientOption(o Option) internal_client.Option { return internal_client.Option{F: o.F} } diff --git a/client/streamxclient/streamxcallopt/call_option.go b/client/streamxclient/streamxcallopt/call_option.go index 2ef16ce3e3..c4acd957fd 100644 --- a/client/streamxclient/streamxcallopt/call_option.go +++ b/client/streamxclient/streamxcallopt/call_option.go @@ -7,8 +7,7 @@ import ( ) type CallOptions struct { - rpcTimeout time.Duration - ProviderOption any + RPCTimeout time.Duration } type CallOption struct { @@ -20,6 +19,6 @@ type WithCallOption func(o *CallOption) func WithRPCTimeout(rpcTimeout time.Duration) CallOption { return CallOption{f: func(o *CallOptions, di *strings.Builder) { di.WriteString(fmt.Sprintf("WithRPCTimeout(%d)", rpcTimeout)) - o.rpcTimeout = rpcTimeout + o.RPCTimeout = rpcTimeout }} } diff --git a/go.sum b/go.sum index 0e40c50be3..806d9d39d1 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,8 @@ github.com/cloudwego/fastpb v0.0.5 h1:vYnBPsfbAtU5TVz5+f9UTlmSCixG9F9vRwaqE0mZPZ github.com/cloudwego/fastpb v0.0.5/go.mod h1:Bho7aAKBUtT9RPD2cNVkTdx4yQumfSv3If7wYnm1izk= github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= -github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4 h1:SHw9GUBBcAnLWeK2MtPH7O6YQG9Q2ZZ8koD/4alpLvE= -github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= +github.com/cloudwego/gopkg v0.1.2-0.20240919030844-cb7123236682 h1:hj/AhlEngERp5Tjt864veEvyK6RglXKcXpxkIOSRfug= +github.com/cloudwego/gopkg v0.1.2-0.20240919030844-cb7123236682/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= <<<<<<< HEAD @@ -34,11 +34,14 @@ github.com/cloudwego/netpoll v0.6.4/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLi ======= github.com/cloudwego/localsession v0.0.2 h1:N9/IDtCPj1fCL9bCTP+DbXx3f40YjVYWcwkJG0YhQkY= github.com/cloudwego/localsession v0.0.2/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiXcs7Z+KBHP72Wv8= +<<<<<<< HEAD github.com/cloudwego/netpoll v0.6.5-0.20240905095957-e6ec47be2fe0 h1:2aoCxK8fee7LhwWveg3ORVEDBoMtmTY2NuSAtNGpnFI= github.com/cloudwego/netpoll v0.6.5-0.20240905095957-e6ec47be2fe0/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= <<<<<<< HEAD >>>>>>> 9e44721 (feat: support multi service (#1538)) ======= +======= +>>>>>>> 2b0e374 (perf: optimise pipe and queue locker) github.com/cloudwego/netpoll v0.6.5-0.20240911073319-2ec9568b10cf h1:c/K4XrkloCgZp+En3LjbXtqfr0KQwC85utUvdDm76V4= github.com/cloudwego/netpoll v0.6.5-0.20240911073319-2ec9568b10cf/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= >>>>>>> 968bdfc (chore: fix unit test (#1545)) diff --git a/internal/client/option.go b/internal/client/option.go index 968b379b82..7452facd0b 100644 --- a/internal/client/option.go +++ b/internal/client/option.go @@ -19,9 +19,10 @@ package client import ( "context" - "github.com/cloudwego/kitex/pkg/streamx" "time" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/localsession/backup" "github.com/cloudwego/kitex/internal/configutil" @@ -83,11 +84,8 @@ type Options struct { ACLRules []acl.RejectFunc - MWBs []endpoint.MiddlewareBuilder - IMWBs []endpoint.MiddlewareBuilder - SMWs []streamx.StreamMiddleware - SRecvMWs []streamx.StreamRecvMiddleware - SSendMWs []streamx.StreamSendMiddleware + MWBs []endpoint.MiddlewareBuilder + IMWBs []endpoint.MiddlewareBuilder Bus event.Bus Events event.Queue @@ -123,7 +121,8 @@ type Options struct { // Context backup CtxBackupHandler backup.BackupHandler - Streaming stream.StreamingConfig + Streaming stream.StreamingConfig + StreamXOptions streamx.ClientOptions } // Apply applies all options. diff --git a/internal/test/port.go b/internal/test/port.go index 4193c67cd3..5a620c13f9 100644 --- a/internal/test/port.go +++ b/internal/test/port.go @@ -47,7 +47,7 @@ func GetLocalAddress() string { for { time.Sleep(time.Millisecond * time.Duration(1+rand.Intn(10))) port := atomic.AddUint32(&curPort, 1+uint32(rand.Intn(10))) - addr := "localhost:" + strconv.Itoa(int(port)) + addr := "127.0.0.1:" + strconv.Itoa(int(port)) if !IsAddressInUse(addr) { trace := strings.Split(string(debug.Stack()), "\n") if len(trace) > 6 { diff --git a/pkg/remote/remotecli/stream.go b/pkg/remote/remotecli/stream.go index 344a2d87b3..6bf541329d 100644 --- a/pkg/remote/remotecli/stream.go +++ b/pkg/remote/remotecli/stream.go @@ -23,12 +23,11 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/pkg/streamx" ) // NewStream create a client side stream -func NewStream(ctx context.Context, ri rpcinfo.RPCInfo, handler remote.ClientTransHandler, opt *remote.ClientOption) (streaming.Stream, *StreamConnManager, error) { - cm := NewConnWrapper(opt.ConnPool) +func NewStream(ctx context.Context, ri rpcinfo.RPCInfo, handler remote.ClientTransHandler, opt *remote.ClientOption) (any, *StreamConnManager, error) { var err error for _, shdlr := range opt.StreamingMetaHandlers { ctx, err = shdlr.OnConnectStream(ctx) @@ -36,6 +35,20 @@ func NewStream(ctx context.Context, ri rpcinfo.RPCInfo, handler remote.ClientTra return nil, nil, err } } + + clientProvider, ok := opt.Provider.(streamx.ClientProvider) + if ok { + // wrap client provider + clientProvider = streamx.NewClientProvider(clientProvider) + cs, err := clientProvider.NewStream(ctx, ri) + if err != nil { + return nil, nil, err + } + return cs, nil, nil + } + + // old version streaming + cm := NewConnWrapper(opt.ConnPool) rawConn, err := cm.GetConn(ctx, opt.Dialer, ri) if err != nil { return nil, nil, err diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index 7e97ed0955..011d555a7e 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -25,6 +25,7 @@ import ( "runtime/debug" "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -92,21 +93,19 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { for { nctx, ss, nerr := t.provider.OnStream(ctx, conn) if nerr != nil { - if !errors.Is(nerr, io.EOF) { - klog.CtxErrorf(ctx, "KITEX: OnStream failed: err=%v", nerr) + if errors.Is(nerr, io.EOF) { + return nil } + klog.CtxErrorf(ctx, "KITEX: OnStream failed: err=%v", nerr) return nerr } // stream level goroutine - go func() { - nerr = t.OnStream(nctx, conn, ss) - if nerr != nil { - if !errors.Is(nerr, io.EOF) { - klog.CtxErrorf(ctx, "KITEX: stream ReadStream failed: err=%v", nerr) - } - return + gofunc.GoFunc(ctx, func() { + err := t.OnStream(nctx, conn, ss) + if err != nil && !errors.Is(err, io.EOF) { + klog.CtxErrorf(ctx, "KITEX: stream ReadStream failed: err=%v", nerr) } - }() + }) } } @@ -157,7 +156,6 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss stream reqArgs := streamx.NewStreamReqArgs(nil) resArgs := streamx.NewStreamResArgs(nil) - // server handler (which will call streamxserver.InvokeStream inside) serr := t.inkHdlFunc(ctx, reqArgs, resArgs) ctx, err = t.provider.OnStreamFinish(ctx, ss) if err == nil && serr != nil { diff --git a/pkg/rpcinfo/interface.go b/pkg/rpcinfo/interface.go index 05b8f4d379..c8249e8605 100644 --- a/pkg/rpcinfo/interface.go +++ b/pkg/rpcinfo/interface.go @@ -73,9 +73,14 @@ type TimeoutProvider interface { Timeouts(ri RPCInfo) Timeouts } +type StreamConfig interface { + StreamRecvTimeout() time.Duration +} + // RPCConfig contains configuration for RPC. type RPCConfig interface { Timeouts + StreamConfig IOBufferSize() int TransportProtocol() transport.Protocol InteractionMode() InteractionMode diff --git a/pkg/rpcinfo/mutable.go b/pkg/rpcinfo/mutable.go index 9c6a6c8802..ee456fe50d 100644 --- a/pkg/rpcinfo/mutable.go +++ b/pkg/rpcinfo/mutable.go @@ -52,6 +52,8 @@ type MutableRPCConfig interface { CopyFrom(from RPCConfig) ImmutableView() RPCConfig SetPayloadCodec(codec serviceinfo.PayloadCodec) + + SetStreamRecvTimeout(timeout time.Duration) } // MutableRPCStats is used to change the information in the RPCStats. diff --git a/pkg/rpcinfo/rpcconfig.go b/pkg/rpcinfo/rpcconfig.go index 3c0f654ca6..4159e8dae2 100644 --- a/pkg/rpcinfo/rpcconfig.go +++ b/pkg/rpcinfo/rpcconfig.go @@ -66,6 +66,9 @@ type rpcConfig struct { transportProtocol transport.Protocol interactionMode InteractionMode payloadCodec serviceinfo.PayloadCodec + + // stream config + streamRecvTimeout time.Duration } func init() { @@ -193,6 +196,14 @@ func (r *rpcConfig) PayloadCodec() serviceinfo.PayloadCodec { return r.payloadCodec } +func (r *rpcConfig) SetStreamRecvTimeout(timeout time.Duration) { + r.streamRecvTimeout = timeout +} + +func (r *rpcConfig) StreamRecvTimeout() time.Duration { + return r.streamRecvTimeout +} + // Clone returns a copy of the current rpcConfig. func (r *rpcConfig) Clone() MutableRPCConfig { r2 := rpcConfigPool.Get().(*rpcConfig) diff --git a/pkg/streamx/client_options.go b/pkg/streamx/client_options.go new file mode 100644 index 0000000000..333fa46eb2 --- /dev/null +++ b/pkg/streamx/client_options.go @@ -0,0 +1,10 @@ +package streamx + +import "time" + +type ClientOptions struct { + RecvTimeout time.Duration + StreamMWs []StreamMiddleware + StreamRecvMWs []StreamRecvMiddleware + StreamSendMWs []StreamSendMiddleware +} diff --git a/pkg/streamx/client_provider.go b/pkg/streamx/client_provider.go index 0f473e085b..a4543ef2f9 100644 --- a/pkg/streamx/client_provider.go +++ b/pkg/streamx/client_provider.go @@ -1,7 +1,24 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx import ( "context" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" "github.com/cloudwego/kitex/pkg/rpcinfo" ) diff --git a/pkg/streamx/client_provider_internal.go b/pkg/streamx/client_provider_internal.go index 53d9968734..d28b83d55b 100644 --- a/pkg/streamx/client_provider_internal.go +++ b/pkg/streamx/client_provider_internal.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx import ( diff --git a/pkg/streamx/provider/jsonrpc/metadata.go b/pkg/streamx/header_trailer.go similarity index 79% rename from pkg/streamx/provider/jsonrpc/metadata.go rename to pkg/streamx/header_trailer.go index 595227a6c7..449bf99847 100644 --- a/pkg/streamx/provider/jsonrpc/metadata.go +++ b/pkg/streamx/header_trailer.go @@ -1,4 +1,4 @@ -package jsonrpc +package streamx type Header map[string]string type Trailer map[string]string diff --git a/pkg/streamx/provider/jsonrpc/client_option.go b/pkg/streamx/provider/jsonrpc/client_option.go index dc434cd382..3a35f5c1f0 100644 --- a/pkg/streamx/provider/jsonrpc/client_option.go +++ b/pkg/streamx/provider/jsonrpc/client_option.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc type ClientProviderOption func(cp *clientProvider) diff --git a/pkg/streamx/provider/jsonrpc/client_provier.go b/pkg/streamx/provider/jsonrpc/client_provier.go index aa4e9fec6c..304187a1be 100644 --- a/pkg/streamx/provider/jsonrpc/client_provier.go +++ b/pkg/streamx/provider/jsonrpc/client_provier.go @@ -1,13 +1,30 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc import ( "context" + "net" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" - "net" ) var _ streamx.ClientProvider = (*clientProvider)(nil) diff --git a/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go b/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go index e872ecfc6f..afae42d52a 100644 --- a/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go +++ b/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc_test import ( @@ -13,19 +29,19 @@ import ( // === gen code === -type ClientStreamingServer[Req, Res any] streamx.ClientStreamingServer[jsonrpc.Header, jsonrpc.Trailer, Req, Res] -type ServerStreamingServer[Res any] streamx.ServerStreamingServer[jsonrpc.Header, jsonrpc.Trailer, Res] -type BidiStreamingServer[Req, Res any] streamx.BidiStreamingServer[jsonrpc.Header, jsonrpc.Trailer, Req, Res] -type ClientStreamingClient[Req, Res any] streamx.ClientStreamingClient[jsonrpc.Header, jsonrpc.Trailer, Req, Res] -type ServerStreamingClient[Res any] streamx.ServerStreamingClient[jsonrpc.Header, jsonrpc.Trailer, Res] -type BidiStreamingClient[Req, Res any] streamx.BidiStreamingClient[jsonrpc.Header, jsonrpc.Trailer, Req, Res] +type ClientStreamingServer[Req, Res any] streamx.ClientStreamingServer[Req, Res] +type ServerStreamingServer[Res any] streamx.ServerStreamingServer[Res] +type BidiStreamingServer[Req, Res any] streamx.BidiStreamingServer[Req, Res] +type ClientStreamingClient[Req, Res any] streamx.ClientStreamingClient[Req, Res] +type ServerStreamingClient[Res any] streamx.ServerStreamingClient[Res] +type BidiStreamingClient[Req, Res any] streamx.BidiStreamingClient[Req, Res] var serviceInfo = &serviceinfo.ServiceInfo{ ServiceName: "a.b.c", Methods: map[string]serviceinfo.MethodInfo{ "Unary": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + return streamxserver.InvokeStream[Request, Response]( ctx, serviceinfo.StreamingUnary, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) }, nil, @@ -35,7 +51,7 @@ var serviceInfo = &serviceinfo.ServiceInfo{ ), "ClientStream": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + return streamxserver.InvokeStream[Request, Response]( ctx, serviceinfo.StreamingClient, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) }, nil, @@ -45,7 +61,7 @@ var serviceInfo = &serviceinfo.ServiceInfo{ ), "ServerStream": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + return streamxserver.InvokeStream[Request, Response]( ctx, serviceinfo.StreamingServer, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) }, nil, @@ -55,7 +71,7 @@ var serviceInfo = &serviceinfo.ServiceInfo{ ), "BidiStream": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + return streamxserver.InvokeStream[Request, Response]( ctx, serviceinfo.StreamingBidirectional, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) }, nil, @@ -135,7 +151,7 @@ type kClient struct { func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (*Response, error) { res := new(Response) - _, err := streamxclient.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + _, err := streamxclient.InvokeStream[Request, Response]( ctx, c.Client, serviceinfo.StreamingUnary, "Unary", req, res, callOptions...) if err != nil { return nil, err @@ -144,18 +160,18 @@ func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...stream } func (c *kClient) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream ClientStreamingClient[Request, Response], err error) { - return streamxclient.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + return streamxclient.InvokeStream[Request, Response]( ctx, c.Client, serviceinfo.StreamingClient, "ClientStream", nil, nil, callOptions...) } func (c *kClient) ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( stream ServerStreamingClient[Response], err error) { - return streamxclient.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + return streamxclient.InvokeStream[Request, Response]( ctx, c.Client, serviceinfo.StreamingServer, "ServerStream", req, nil, callOptions...) } func (c *kClient) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( stream BidiStreamingClient[Request, Response], err error) { - return streamxclient.InvokeStream[jsonrpc.Header, jsonrpc.Trailer, Request, Response]( + return streamxclient.InvokeStream[Request, Response]( ctx, c.Client, serviceinfo.StreamingBidirectional, "BidiStream", nil, nil, callOptions...) } diff --git a/pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go b/pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go index b4019c9575..c3ba1b7644 100644 --- a/pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go +++ b/pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc_test import ( diff --git a/pkg/streamx/provider/jsonrpc/jsonrpc_test.go b/pkg/streamx/provider/jsonrpc/jsonrpc_test.go index f7dc9e7f65..ecfa8c3f6e 100644 --- a/pkg/streamx/provider/jsonrpc/jsonrpc_test.go +++ b/pkg/streamx/provider/jsonrpc/jsonrpc_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc_test import ( diff --git a/pkg/streamx/provider/jsonrpc/protocol.go b/pkg/streamx/provider/jsonrpc/protocol.go index a6a1d29523..3d8654ea29 100644 --- a/pkg/streamx/provider/jsonrpc/protocol.go +++ b/pkg/streamx/provider/jsonrpc/protocol.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc import ( diff --git a/pkg/streamx/provider/jsonrpc/server_option.go b/pkg/streamx/provider/jsonrpc/server_option.go index 6282ebe446..0cc2ae4017 100644 --- a/pkg/streamx/provider/jsonrpc/server_option.go +++ b/pkg/streamx/provider/jsonrpc/server_option.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc type ServerProviderOption func(pc *serverProvider) diff --git a/pkg/streamx/provider/jsonrpc/server_provider.go b/pkg/streamx/provider/jsonrpc/server_provider.go index f38a7e0834..f689c6b7c5 100644 --- a/pkg/streamx/provider/jsonrpc/server_provider.go +++ b/pkg/streamx/provider/jsonrpc/server_provider.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc import ( diff --git a/pkg/streamx/provider/jsonrpc/stream.go b/pkg/streamx/provider/jsonrpc/stream.go index c9bf141905..6e36b95d46 100644 --- a/pkg/streamx/provider/jsonrpc/stream.go +++ b/pkg/streamx/provider/jsonrpc/stream.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc import ( @@ -9,10 +25,10 @@ import ( ) var ( - _ streamx.ClientStream = (*clientStream)(nil) - _ streamx.ServerStream = (*serverStream)(nil) - _ streamx.ClientStreamMetadata[Header, Trailer] = (*clientStream)(nil) - _ streamx.ServerStreamMetadata[Header, Trailer] = (*serverStream)(nil) + _ streamx.ClientStream = (*clientStream)(nil) + _ streamx.ServerStream = (*serverStream)(nil) + _ streamx.ClientStreamMetadata = (*clientStream)(nil) + _ streamx.ServerStreamMetadata = (*serverStream)(nil) ) func newStream(trans *transport, sid int, mode streamx.StreamingMode, service, method string) (s *stream) { @@ -35,12 +51,12 @@ type stream struct { trans *transport } -func (s *stream) Header() (Header, error) { - return make(Header), nil +func (s *stream) Header() (streamx.Header, error) { + return make(streamx.Header), nil } -func (s *stream) Trailer() (Trailer, error) { - return make(Trailer), nil +func (s *stream) Trailer() (streamx.Trailer, error) { + return make(streamx.Trailer), nil } func (s *stream) Mode() streamx.StreamingMode { @@ -109,14 +125,14 @@ type serverStream struct { *stream } -func (s *serverStream) SetHeader(hd Header) error { +func (s *serverStream) SetHeader(hd streamx.Header) error { return nil } -func (s *serverStream) SendHeader(hd Header) error { +func (s *serverStream) SendHeader(hd streamx.Header) error { return nil } -func (s *serverStream) SetTrailer(hd Trailer) error { +func (s *serverStream) SetTrailer(hd streamx.Trailer) error { return nil } diff --git a/pkg/streamx/provider/jsonrpc/transport.go b/pkg/streamx/provider/jsonrpc/transport.go index d36b6131c8..e761f89083 100644 --- a/pkg/streamx/provider/jsonrpc/transport.go +++ b/pkg/streamx/provider/jsonrpc/transport.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc import ( @@ -49,7 +65,11 @@ func newTransport(sinfo *serviceinfo.ServiceInfo, conn net.Conn) *transport { } func (t *transport) close() (err error) { - close(t.stop) + select { + case <-t.stop: + default: + close(t.stop) + } return nil } diff --git a/pkg/streamx/provider/jsonrpc/transport_test.go b/pkg/streamx/provider/jsonrpc/transport_test.go index 4df6e66ac4..0327eb12f6 100644 --- a/pkg/streamx/provider/jsonrpc/transport_test.go +++ b/pkg/streamx/provider/jsonrpc/transport_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package jsonrpc import ( diff --git a/pkg/streamx/provider/ttstream/client_option.go b/pkg/streamx/provider/ttstream/client_option.go index 6d0e3a4db3..d27fb6c3fe 100644 --- a/pkg/streamx/provider/ttstream/client_option.go +++ b/pkg/streamx/provider/ttstream/client_option.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream type ClientProviderOption func(cp *clientProvider) @@ -7,3 +23,27 @@ func WithClientMetaHandler(metaHandler MetaFrameHandler) ClientProviderOption { cp.metaHandler = metaHandler } } + +func WithClientHeaderHandler(handler HeaderFrameHandler) ClientProviderOption { + return func(cp *clientProvider) { + cp.headerHandler = handler + } +} + +func WithClientLongConnPool(config LongConnConfig) ClientProviderOption { + return func(cp *clientProvider) { + cp.transPool = newLongConnTransPool(config) + } +} + +func WithClientShortConnPool() ClientProviderOption { + return func(cp *clientProvider) { + cp.transPool = newShortConnTransPool() + } +} + +func WithClientMuxConnPool() ClientProviderOption { + return func(cp *clientProvider) { + cp.transPool = newMuxTransPool() + } +} diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go index de3a51a188..0762bf30f6 100644 --- a/pkg/streamx/provider/ttstream/client_provier.go +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -1,17 +1,34 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( "context" "runtime" + "sync/atomic" "github.com/bytedance/gopkg/cloud/metainfo" "github.com/cloudwego/gopkg/protocol/ttheader" "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" "github.com/cloudwego/kitex/pkg/kerrors" - "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/ktx" ) var _ streamx.ClientProvider = (*clientProvider)(nil) @@ -19,20 +36,22 @@ var _ streamx.ClientProvider = (*clientProvider)(nil) func NewClientProvider(sinfo *serviceinfo.ServiceInfo, opts ...ClientProviderOption) (streamx.ClientProvider, error) { cp := new(clientProvider) cp.sinfo = sinfo + cp.transPool = newMuxTransPool() for _, opt := range opts { opt(cp) } - cp.transPool = newMuxTransPool(sinfo) return cp, nil } type clientProvider struct { - transPool transPool - sinfo *serviceinfo.ServiceInfo - metaHandler MetaFrameHandler + transPool transPool + sinfo *serviceinfo.ServiceInfo + metaHandler MetaFrameHandler + headerHandler HeaderFrameHandler } func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callOptions ...streamxcallopt.CallOption) (streamx.ClientStream, error) { + rconfig := ri.Config() invocation := ri.Invocation() method := invocation.MethodName() addr := ri.To().Address() @@ -40,26 +59,64 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callO return nil, kerrors.ErrNoDestAddress } - trans, err := c.transPool.Get(addr.Network(), addr.String()) + var strHeader streamx.Header + var intHeader IntHeader + var err error + if c.headerHandler != nil { + intHeader, strHeader, err = c.headerHandler.OnStream(ctx) + if err != nil { + return nil, err + } + } else { + intHeader = IntHeader{} + strHeader = map[string]string{} + } + strHeader[ttheader.HeaderIDLServiceName] = c.sinfo.ServiceName + metainfo.SaveMetaInfoToMap(ctx, strHeader) + + trans, err := c.transPool.Get(c.sinfo, addr.Network(), addr.String()) if err != nil { return nil, err } - header := map[string]string{ - ttheader.HeaderIDLServiceName: c.sinfo.ServiceName, - } - metainfo.SaveMetaInfoToMap(ctx, header) - s, err := trans.newStream(ctx, method, header) + sio, err := trans.newStreamIO(ctx, method, intHeader, strHeader) if err != nil { return nil, err } + sio.stream.setRecvTimeout(rconfig.StreamRecvTimeout()) // only client can set meta frame handler - s.setMetaFrameHandler(c.metaHandler) - cs := newClientStream(s) - runtime.SetFinalizer(cs, func(cs *clientStream) { - klog.Debugf("client stream[%v] closing", cs.sid) - _ = cs.close() - c.transPool.Put(trans) + sio.stream.setMetaFrameHandler(c.metaHandler) + + // if ctx from server side, we should cancel the stream when server handler already returned + // TODO: this canceling transmit should be configurable + ktx.RegisterCancelCallback(ctx, func() { + sio.stream.cancel() + }) + + cs := newClientStream(sio.stream) + // the END of a client stream means it should send and recv trailer and not hold by user anymore + var ended uint32 + sio.setEOFCallback(func() { + // if stream is ended by both parties, put the transport back to pool + sio.stream.close() + if atomic.AddUint32(&ended, 1) == 2 { + if trans.IsActive() { + c.transPool.Put(trans) + } + err = trans.streamDelete(sio.stream.sid) + } + }) + runtime.SetFinalizer(cs, func(cstream *clientStream) { + // it's safe to call CloseSend twice + // we do repeated CloseSend here to ensure stream can be closed normally + _ = cstream.CloseSend(ctx) + // only delete stream when clientStream be finalized + if atomic.AddUint32(&ended, 1) == 2 { + if trans.IsActive() { + c.transPool.Put(trans) + } + err = trans.streamDelete(sio.stream.sid) + } }) return cs, err } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool.go b/pkg/streamx/provider/ttstream/client_trans_pool.go index 57eb0bbbd4..f694b9db3c 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool.go @@ -1,10 +1,24 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream -// TODO: it's to complex for users implement idle check -// so let's implement it in netpoll +import "github.com/cloudwego/kitex/pkg/serviceinfo" -// TODO: make transPoll configurable type transPool interface { - Get(network string, addr string) (trans *transport, err error) + Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (trans *transport, err error) Put(trans *transport) } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go index 8a54f332c7..c7657ea929 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go @@ -1,102 +1,75 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( - "runtime" - "sync" "time" "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" "github.com/cloudwego/netpoll" ) -func newTransPool(sinfo *serviceinfo.ServiceInfo) *longConnTransPool { - tp := &longConnTransPool{sinfo: sinfo} - go func() { - now := time.Now() - deleteKeys := make([]string, 0) - tp.pool.Range(func(addr, value any) bool { - tstack := value.(*transStack) - duration := now.Sub(tstack.modified) - if duration >= time.Minute*10 { - deleteKeys = append(deleteKeys, addr.(string)) - } - return true - }) - }() +var DefaultLongConnConfig = LongConnConfig{ + MaxIdleTimeout: time.Minute, +} + +type LongConnConfig struct { + MaxIdleTimeout time.Duration +} + +func newLongConnTransPool(config LongConnConfig) transPool { + tp := new(longConnTransPool) + tp.config = DefaultLongConnConfig + if config.MaxIdleTimeout > 0 { + tp.config.MaxIdleTimeout = config.MaxIdleTimeout + } + tp.transPool = container.NewObjectPool(tp.config.MaxIdleTimeout) return tp } type longConnTransPool struct { - pool sync.Map // {"addr":*transStack} - sinfo *serviceinfo.ServiceInfo + transPool *container.ObjectPool + config LongConnConfig } -func (c *longConnTransPool) Get(network string, addr string) (trans *transport, err error) { - var cstack *transStack - val, ok := c.pool.Load(addr) - if !ok { - // TODO: here may have a race problem - cstack = newTransStack() - _, _ = c.pool.LoadOrStore(addr, cstack) - } else { - cstack = val.(*transStack) - } - trans = cstack.Pop() - if trans != nil { - return trans, nil +func (c *longConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (trans *transport, err error) { + for { + o := c.transPool.Pop(addr) + if o == nil { + break + } + trans = o.(*transport) + if trans.IsActive() { + return trans, nil + } } + + // create new connection conn, err := netpoll.DialConnection(network, addr, time.Second) if err != nil { return nil, err } - trans = newTransport(clientTransport, c.sinfo, conn) - runtime.SetFinalizer(trans, func(t *transport) { t.Close() }) + trans = newTransport(clientTransport, sinfo, conn) + // create new transport return trans, nil } func (c *longConnTransPool) Put(trans *transport) { - var cstack *transStack - val, ok := c.pool.Load(trans.conn.RemoteAddr()) - if !ok { - return - } - cstack = val.(*transStack) - cstack.Push(trans) -} - -func newTransStack() *transStack { - return &transStack{} -} - -// FILO -type transStack struct { - mu sync.Mutex - stack []*transport // TODO: now it's a mem leak stack implementation - modified time.Time -} - -func (s *transStack) Pop() (trans *transport) { - s.mu.Lock() - if len(s.stack) == 0 { - s.mu.Unlock() - return nil - } - trans = s.stack[len(s.stack)-1] - s.stack = s.stack[:len(s.stack)-1] - s.mu.Unlock() - return trans -} - -func (s *transStack) Push(trans *transport) { - s.mu.Lock() - s.stack = append(s.stack, trans) - s.modified = time.Now() - s.mu.Unlock() -} - -func (s *transStack) Clear() { - s.mu.Lock() - s.stack = []*transport{} - s.modified = time.Now() - s.mu.Unlock() + addr := trans.conn.RemoteAddr().String() + c.transPool.Push(addr, trans) } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go index b08677a274..4775cb40bc 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go @@ -1,6 +1,23 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( + "runtime" "sync" "time" @@ -11,21 +28,17 @@ import ( var _ transPool = (*muxTransPool)(nil) -func newMuxTransPool(sinfo *serviceinfo.ServiceInfo) transPool { +func newMuxTransPool() transPool { t := new(muxTransPool) - t.sinfo = sinfo - t.scavenger = newScavenger() return t } type muxTransPool struct { - sinfo *serviceinfo.ServiceInfo - pool sync.Map // addr:netpoll.Connection - scavenger *scavenger - sflight singleflight.Group + pool sync.Map // addr:*transport + sflight singleflight.Group } -func (m *muxTransPool) Get(network string, addr string) (trans *transport, err error) { +func (m *muxTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (trans *transport, err error) { v, ok := m.pool.Load(addr) if ok { return v.(*transport), nil @@ -35,8 +48,17 @@ func (m *muxTransPool) Get(network string, addr string) (trans *transport, err e if err != nil { return nil, err } - trans = newTransport(clientTransport, m.sinfo, conn) - m.scavenger.Add(trans) + trans = newTransport(clientTransport, sinfo, conn) + _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { + // peer close + _ = trans.Close() + return nil + }) + m.pool.Store(addr, trans) + runtime.SetFinalizer(trans, func(trans *transport) { + // self close when not hold by user + _ = trans.Close() + }) return trans, nil }) if err != nil { diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go new file mode 100644 index 0000000000..37c803b3c7 --- /dev/null +++ b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go @@ -0,0 +1,45 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "time" + + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/netpoll" +) + +func newShortConnTransPool() transPool { + return &shortConnTransPool{} +} + +type shortConnTransPool struct{} + +func (c *shortConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (*transport, error) { + // create new connection + conn, err := netpoll.DialConnection(network, addr, time.Second) + if err != nil { + return nil, err + } + // create new transport + trans := newTransport(clientTransport, sinfo, conn) + return trans, nil +} + +func (c *shortConnTransPool) Put(trans *transport) { + _ = trans.Close() +} diff --git a/pkg/streamx/provider/ttstream/container/linklist.go b/pkg/streamx/provider/ttstream/container/linklist.go new file mode 100644 index 0000000000..b9b26c4627 --- /dev/null +++ b/pkg/streamx/provider/ttstream/container/linklist.go @@ -0,0 +1,41 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package container + +type linkNode[ValueType any] struct { + val ValueType + next *linkNode[ValueType] +} + +func (n *linkNode[ValueType]) reset() { + var nilVal ValueType + n.val = nilVal + n.next = nil +} + +type doubleLinkNode[ValueType any] struct { + val ValueType + next *doubleLinkNode[ValueType] + last *doubleLinkNode[ValueType] +} + +func (n *doubleLinkNode[ValueType]) reset() { + var nilVal ValueType + n.val = nilVal + n.next = nil + n.last = nil +} diff --git a/pkg/streamx/provider/ttstream/container/object_pool.go b/pkg/streamx/provider/ttstream/container/object_pool.go new file mode 100644 index 0000000000..26fdc26978 --- /dev/null +++ b/pkg/streamx/provider/ttstream/container/object_pool.go @@ -0,0 +1,110 @@ +package container + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/cloudwego/kitex/pkg/klog" +) + +type Object interface { + Close() error +} + +type objectItem struct { + object Object + lastActive time.Time +} + +func NewObjectPool(idleTimeout time.Duration) *ObjectPool { + s := new(ObjectPool) + s.idleTimeout = idleTimeout + s.objects = make(map[string]*Stack[objectItem]) + go s.cleaning() + return s +} + +type ObjectPool struct { + L sync.Mutex // STW + objects map[string]*Stack[objectItem] + idleTimeout time.Duration + closed int32 +} + +func (s *ObjectPool) Push(key string, o Object) { + s.L.Lock() + stk := s.objects[key] + if stk == nil { + stk = NewStack[objectItem]() + s.objects[key] = stk + } + s.L.Unlock() + stk.Push(objectItem{object: o, lastActive: time.Now()}) +} + +func (s *ObjectPool) Pop(key string) Object { + s.L.Lock() + stk := s.objects[key] + s.L.Unlock() + if stk == nil { + return nil + } + o, ok := stk.Pop() + if !ok { + return nil + } + return o.object +} + +func (s *ObjectPool) Close() { + atomic.CompareAndSwapInt32(&s.closed, 0, 1) +} + +func (s *ObjectPool) cleaning() { + cleanInternal := time.Second + for atomic.LoadInt32(&s.closed) == 0 { + time.Sleep(cleanInternal) + + now := time.Now() + s.L.Lock() + // update cleanInternal + objSize := 0 + for _, stk := range s.objects { + objSize += stk.Size() + } + cleanInternal = time.Second + time.Duration(objSize)*time.Millisecond*10 + if cleanInternal > time.Second*10 { + cleanInternal = time.Second * 10 + } + // clean objects + for key, stk := range s.objects { + deleted := 0 + var oldest *time.Time + klog.Infof("object[%s] pool cleaning %d objects", key, stk.Size()) + stk.RangeDelete(func(o objectItem) (deleteNode bool, continueRange bool) { + if oldest == nil { + oldest = &o.lastActive + } + if o.object == nil { + deleted++ + return true, true + } + + // RangeDelete start from the stack bottom + // we assume that the values on the top of last valid value are all valid + if now.Sub(o.lastActive) < s.idleTimeout { + return false, false + } + deleted++ + err := o.object.Close() + klog.Infof("object is invalid: lastActive=%s, closedErr=%v", o.lastActive.String(), err) + return true, true + }) + if oldest != nil { + klog.Infof("object[%s] pool deleted %d objects, oldest=%s", key, deleted, oldest.String()) + } + } + s.L.Unlock() + } +} diff --git a/pkg/streamx/provider/ttstream/container/pipe.go b/pkg/streamx/provider/ttstream/container/pipe.go new file mode 100644 index 0000000000..15c92590ed --- /dev/null +++ b/pkg/streamx/provider/ttstream/container/pipe.go @@ -0,0 +1,122 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package container + +import ( + "context" + "fmt" + "io" + "sync/atomic" +) + +type pipeState = int32 + +const ( + pipeStateInactive pipeState = 0 + pipeStateActive pipeState = 1 + pipeStateClosed pipeState = 2 + pipeStateCanceled pipeState = 3 +) + +var ErrPipeEOF = io.EOF +var ErrPipeCanceled = fmt.Errorf("pipe canceled") +var stateErrors map[pipeState]error = map[pipeState]error{ + pipeStateClosed: ErrPipeEOF, + pipeStateCanceled: ErrPipeCanceled, +} + +// Pipe implement a queue that never block on Write but block on Read if there is nothing to read +type Pipe[Item any] struct { + queue *Queue[Item] + trigger chan struct{} + state pipeState +} + +func NewPipe[Item any]() *Pipe[Item] { + p := new(Pipe[Item]) + p.queue = NewQueue[Item]() + p.trigger = make(chan struct{}, 1) + return p +} + +// Read will block if there is nothing to read +func (p *Pipe[Item]) Read(ctx context.Context, items []Item) (int, error) { +READ: + var n int + for i := 0; i < len(items); i++ { + val, ok := p.queue.Get() + if !ok { + break + } + items[i] = val + n++ + } + if n > 0 { + return n, nil + } + + // no data to read, waiting writes + for { + if ctx.Done() != nil { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-p.trigger: + } + } else { + <-p.trigger + } + + if p.queue.Size() == 0 { + err := stateErrors[atomic.LoadInt32(&p.state)] + if err != nil { + return 0, err + } + } + goto READ + } +} + +func (p *Pipe[Item]) Write(ctx context.Context, items ...Item) error { + if !atomic.CompareAndSwapInt32(&p.state, pipeStateInactive, pipeStateActive) && atomic.LoadInt32(&p.state) != pipeStateActive { + err := stateErrors[atomic.LoadInt32(&p.state)] + if err != nil { + return err + } + return fmt.Errorf("unknown state error") + } + + for _, item := range items { + p.queue.Add(item) + } + // wake up + select { + case p.trigger <- struct{}{}: + default: + } + return nil +} + +func (p *Pipe[Item]) Close() { + atomic.StoreInt32(&p.state, pipeStateClosed) + close(p.trigger) +} + +func (p *Pipe[Item]) Cancel() { + atomic.StoreInt32(&p.state, pipeStateCanceled) + close(p.trigger) +} diff --git a/pkg/streamx/provider/ttstream/container/pipe_test.go b/pkg/streamx/provider/ttstream/container/pipe_test.go new file mode 100644 index 0000000000..761589d30b --- /dev/null +++ b/pkg/streamx/provider/ttstream/container/pipe_test.go @@ -0,0 +1,72 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package container + +import ( + "context" + "sync" + "testing" +) + +func TestPipeline(t *testing.T) { + ctx := context.Background() + pipe := NewPipe[int]() + var recv int + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + items := make([]int, 10) + for { + n, err := pipe.Read(ctx, items) + if err != nil { + return + } + for i := 0; i < n; i++ { + recv += items[i] + } + } + }() + round := 10000 + itemsPerRound := []int{1, 1, 1, 1, 1} + for i := 0; i < round; i++ { + _ = pipe.Write(ctx, itemsPerRound...) + } + pipe.Close() + wg.Wait() + if recv != len(itemsPerRound)*round { + t.Fatalf("expect %d items, got %d", len(itemsPerRound)*round, recv) + } +} + +func BenchmarkPipeline(b *testing.B) { + ctx := context.Background() + pipe := NewPipe[int]() + readCache := make([]int, 8) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for j := 0; j < len(readCache); j++ { + go pipe.Write(ctx, 1) + } + total := 0 + for total < len(readCache) { + n, _ := pipe.Read(ctx, readCache) + total += n + } + } +} diff --git a/pkg/streamx/provider/ttstream/container/queue.go b/pkg/streamx/provider/ttstream/container/queue.go new file mode 100644 index 0000000000..0a074d2d27 --- /dev/null +++ b/pkg/streamx/provider/ttstream/container/queue.go @@ -0,0 +1,117 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package container + +import ( + "runtime" + "sync" + "sync/atomic" +) + +func NewQueue[ValueType any]() *Queue[ValueType] { + return &Queue[ValueType]{} +} + +// Queue implement a concurrent-safe queue +type Queue[ValueType any] struct { + head *linkNode[ValueType] // head will be protected by Locker + tail *linkNode[ValueType] // tail will be protected by Locker + read *linkNode[ValueType] // read can only access by func Get() + safety int32 + size int32 + + nodePool sync.Pool +} + +func (q *Queue[ValueType]) lock() { + for !atomic.CompareAndSwapInt32(&q.safety, 0, 1) { + runtime.Gosched() + } +} + +func (q *Queue[ValueType]) unlock() { + atomic.StoreInt32(&q.safety, 0) +} + +func (q *Queue[ValueType]) Size() int { + return int(atomic.LoadInt32(&q.size)) +} + +func (q *Queue[ValueType]) Get() (val ValueType, ok bool) { +Start: + // fast path + if q.read != nil { + node := q.read + val = node.val + q.read = node.next + atomic.AddInt32(&q.size, -1) + + // reset node + node.reset() + q.nodePool.Put(node) + return val, true + } + + // slow path + q.lock() + if q.head == nil { + q.unlock() + return val, false + } + // single read + if q.head.next == nil { + node := q.head + val = node.val + q.head = nil + q.tail = nil + atomic.AddInt32(&q.size, -1) + q.unlock() + + // reset node + node.reset() + q.nodePool.Put(node) + return val, true + } + // transfer main linklist into q.read list and clear main linklist + q.read = q.head + q.head = nil + q.tail = nil + q.unlock() + goto Start +} + +func (q *Queue[ValueType]) Add(val ValueType) { + var node *linkNode[ValueType] + v := q.nodePool.Get() + if v == nil { + node = new(linkNode[ValueType]) + } else { + node = v.(*linkNode[ValueType]) + } + node.val = val + + q.lock() + if q.tail == nil { + q.head = node + q.tail = node + } else { + q.tail.next = node + q.tail = q.tail.next + } + atomic.AddInt32(&q.size, 1) + q.unlock() +} diff --git a/pkg/streamx/provider/ttstream/container/queue_test.go b/pkg/streamx/provider/ttstream/container/queue_test.go new file mode 100644 index 0000000000..b3b8b23f7a --- /dev/null +++ b/pkg/streamx/provider/ttstream/container/queue_test.go @@ -0,0 +1,45 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package container + +import ( + "sync" + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestQueue(t *testing.T) { + q := NewQueue[int]() + round := 100000 + var wg sync.WaitGroup + wg.Add(1) + go func() { + for i := 0; i < round; i++ { + q.Add(1) + } + }() + sum := 0 + for sum < round { + v, ok := q.Get() + if ok { + sum += v + } + } + test.DeepEqual(t, sum, round) + test.DeepEqual(t, q.Size(), 0) +} diff --git a/pkg/streamx/provider/ttstream/container/stack.go b/pkg/streamx/provider/ttstream/container/stack.go new file mode 100644 index 0000000000..7083fc27e0 --- /dev/null +++ b/pkg/streamx/provider/ttstream/container/stack.go @@ -0,0 +1,170 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package container + +import ( + "sync" +) + +func NewStack[ValueType any]() *Stack[ValueType] { + return &Stack[ValueType]{} +} + +type Stack[ValueType any] struct { + L sync.Mutex + head *doubleLinkNode[ValueType] // head will be protected by Locker + tail *doubleLinkNode[ValueType] // tail will be protected by Locker + size int + nodePool sync.Pool +} + +func (s *Stack[ValueType]) Size() (size int) { + s.L.Lock() + size = s.size + s.L.Unlock() + return size +} + +// RangeDelete range from the stack bottom +func (s *Stack[ValueType]) RangeDelete(checking func(v ValueType) (deleteNode bool, continueRange bool)) { + // Stop the world! + s.L.Lock() + // range from the stack bottom(oldest item) + node := s.head + deleteNode := false + continueRange := true + for node != nil && continueRange { + deleteNode, continueRange = checking(node.val) + if !deleteNode { + node = node.next + continue + } + // skip current node + last := node.last + next := node.next + // modify last node + if last != nil { + // change last next ptr + last.next = next + } + // modify next node + if next != nil { + next.last = last + } + // modify link list + if s.head == node { + s.head = next + } + if s.tail == node { + s.tail = last + } + node = node.next + s.size -= 1 + } + s.L.Unlock() +} + +func (s *Stack[ValueType]) pop() (node *doubleLinkNode[ValueType]) { + if s.tail == nil { + return nil + } + node = s.tail + if node.last == nil { + // if node is the only node in the list, clear the whole linklist + s.head = nil + s.tail = nil + } else { + // if node is not the only node in the list, only modify the list's tail + s.tail = node.last + s.tail.next = nil + } + s.size-- + return node +} + +func (s *Stack[ValueType]) Pop() (value ValueType, ok bool) { + s.L.Lock() + node := s.pop() + s.L.Unlock() + if node == nil { + return value, false + } + + value = node.val + node.reset() + s.nodePool.Put(node) + return value, true +} + +func (s *Stack[ValueType]) popBottom() (node *doubleLinkNode[ValueType]) { + if s.head == nil { + return nil + } + node = s.head + if node.next == nil { + // if node is the only node in the list, clear the whole linklist + s.head = nil + s.tail = nil + } else { + // if node is not the only node in the list, only modify the list's head + s.head = s.head.next + s.head.last = nil + } + s.size-- + return node +} + +func (s *Stack[ValueType]) PopBottom() (value ValueType, ok bool) { + s.L.Lock() + node := s.popBottom() + s.L.Unlock() + if node == nil { + return value, false + } + + value = node.val + node.reset() + s.nodePool.Put(node) + return value, true +} + +func (s *Stack[ValueType]) Push(value ValueType) { + var node *doubleLinkNode[ValueType] + v := s.nodePool.Get() + if v == nil { + node = &doubleLinkNode[ValueType]{} + } else { + node = v.(*doubleLinkNode[ValueType]) + } + node.val = value + node.next = nil + + s.L.Lock() + if s.tail == nil { + // first node + node.last = nil + s.head = node + s.tail = node + } else { + // not first node + node.last = s.tail + s.tail.next = node + s.tail = node + } + s.size++ + s.L.Unlock() +} diff --git a/pkg/streamx/provider/ttstream/container/stack_test.go b/pkg/streamx/provider/ttstream/container/stack_test.go new file mode 100644 index 0000000000..420bd636b5 --- /dev/null +++ b/pkg/streamx/provider/ttstream/container/stack_test.go @@ -0,0 +1,99 @@ +package container + +import ( + "sync" + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestStack(t *testing.T) { + stk := NewStack[int]() + round := 100000 + var wg sync.WaitGroup + wg.Add(1) + go func() { + for i := 1; i <= round; i++ { + stk.Push(1) + } + }() + sum := 0 + var v int + var ok bool + for sum < round { + if sum%2 == 0 { + v, ok = stk.Pop() + } else { + v, ok = stk.PopBottom() + } + if ok { + sum += v + } + } + test.DeepEqual(t, sum, round) +} + +func TestStackOrder(t *testing.T) { + stk := NewStack[int]() + round := 100000 + for i := 1; i <= round; i++ { + stk.Push(i) + } + target := round + for { + v, ok := stk.Pop() + if !ok { + break + } + test.DeepEqual(t, v, target) + target-- + } + test.DeepEqual(t, target, 0) +} + +func TestStackPopBottomOrder(t *testing.T) { + stk := NewStack[int]() + round := 100000 + for i := 0; i < round; i++ { + stk.Push(i) + } + target := 0 + for { + v, ok := stk.PopBottom() + if !ok { + break + } + test.DeepEqual(t, v, target) + target++ + } + test.DeepEqual(t, target, round) +} + +func TestStackRangeDelete(t *testing.T) { + stk := NewStack[int]() + round := 1000 + for i := 1; i <= round; i++ { + stk.Push(i) + } + stk.RangeDelete(func(v int) (deleteNode bool, continueRange bool) { + return v%2 == 0, true + }) + test.Assert(t, stk.Size() == round/2, stk.Size()) + size := 0 + stk.RangeDelete(func(v int) (deleteNode bool, continueRange bool) { + size++ + return false, true + }) + test.Assert(t, size == round/2, size) + + size = 0 + for { + _, ok := stk.Pop() + if ok { + size++ + } else { + break + } + } + test.Assert(t, size == round/2, size) +} diff --git a/pkg/streamx/provider/ttstream/frame.go b/pkg/streamx/provider/ttstream/frame.go index 41309ee945..7732f079d1 100644 --- a/pkg/streamx/provider/ttstream/frame.go +++ b/pkg/streamx/provider/ttstream/frame.go @@ -1,13 +1,33 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( "context" "encoding/binary" "fmt" + "sync" + "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/gopkg/bufiox" - "github.com/cloudwego/gopkg/protocol/thrift" + gopkgthrift "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" + "github.com/cloudwego/kitex/pkg/streamx" ) const ( @@ -24,6 +44,8 @@ var frameTypeToString = map[int32]string{ trailerFrameType: ttheader.FrameTypeTrailer, } +var framePool sync.Pool + type Frame struct { streamFrame meta IntHeader @@ -31,15 +53,31 @@ type Frame struct { payload []byte } -func newFrame(meta streamFrame, typ int32, payload []byte) Frame { - return Frame{ - streamFrame: meta, - typ: typ, - payload: payload, +func newFrame(sframe streamFrame, meta IntHeader, typ int32, payload []byte) (fr *Frame) { + v := framePool.Get() + if v == nil { + fr = new(Frame) + } else { + fr = v.(*Frame) } + fr.streamFrame = sframe + fr.meta = meta + fr.typ = typ + fr.payload = payload + return fr +} + +func recycleFrame(frame *Frame) { + frame.streamFrame = streamFrame{} + frame.meta = nil + frame.typ = 0 + frame.payload = nil + framePool.Put(frame) } -func EncodeFrame(ctx context.Context, writer bufiox.Writer, fr Frame) (err error) { +// EncodeFrame will not call Flush! +func EncodeFrame(ctx context.Context, writer bufiox.Writer, fr *Frame) (err error) { + written := writer.WrittenLen() param := ttheader.EncodeParam{ Flags: ttheader.HeaderFlagsStreaming, SeqID: fr.sid, @@ -65,66 +103,80 @@ func EncodeFrame(ctx context.Context, writer bufiox.Writer, fr Frame) (err error return err } if len(fr.payload) > 0 { - _, err = writer.WriteBinary(fr.payload) + if nw, ok := writer.(gopkgthrift.NocopyWriter); ok { + err = nw.WriteDirect(fr.payload, 0) + } else { + _, err = writer.WriteBinary(fr.payload) + } if err != nil { return err } } - binary.BigEndian.PutUint32(totalLenField, uint32(writer.WrittenLen()-4)) - err = writer.Flush() - return err + written = writer.WrittenLen() - written + binary.BigEndian.PutUint32(totalLenField, uint32(written-4)) + return nil } -func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr Frame, err error) { +func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err error) { var dp ttheader.DecodeParam dp, err = ttheader.Decode(ctx, reader) if err != nil { return } - if dp.Flags != ttheader.HeaderFlagsStreaming { + if dp.Flags&ttheader.HeaderFlagsStreaming == 0 { err = fmt.Errorf("unexpected header flags: %d", dp.Flags) return } - fr.meta = dp.IntInfo - frtype := fr.meta[ttheader.FrameType] - switch frtype { + var ftype int32 + var fheader streamx.Header + var ftrailer streamx.Trailer + switch dp.IntInfo[ttheader.FrameType] { case ttheader.FrameTypeMeta: - fr.typ = metaFrameType + ftype = metaFrameType case ttheader.FrameTypeHeader: - fr.typ = headerFrameType - fr.header = dp.StrInfo + ftype = headerFrameType + fheader = dp.StrInfo case ttheader.FrameTypeData: - fr.typ = dataFrameType + ftype = dataFrameType case ttheader.FrameTypeTrailer: - fr.typ = trailerFrameType - fr.trailer = dp.StrInfo + ftype = trailerFrameType + ftrailer = dp.StrInfo default: - err = fmt.Errorf("unexpected frame type: %v", fr.meta[ttheader.FrameType]) + err = fmt.Errorf("unexpected frame type: %v", dp.IntInfo[ttheader.FrameType]) return } - // stream meta - fr.sid = dp.SeqID - fr.method = fr.meta[ttheader.ToMethod] + fmethod := dp.IntInfo[ttheader.ToMethod] + fsid := dp.SeqID // frame payload - if dp.PayloadLen == 0 { - return fr, nil - } - fr.payload = make([]byte, dp.PayloadLen) - _, err = reader.ReadBinary(fr.payload) - reader.Release(err) - if err != nil { - return + var fpayload []byte + if dp.PayloadLen > 0 { + fpayload = mcache.Malloc(dp.PayloadLen) + _, err = reader.ReadBinary(fpayload) // copy read + _ = reader.Release(err) + if err != nil { + return + } + } else { + _ = reader.Release(nil) } + + fr = newFrame( + streamFrame{sid: fsid, method: fmethod, header: fheader, trailer: ftrailer}, + dp.IntInfo, + ftype, fpayload, + ) return fr, nil } +var thriftCodec = thrift.NewThriftCodec() + func EncodePayload(ctx context.Context, msg any) ([]byte, error) { - return thrift.FastMarshal(msg.(thrift.FastCodec)), nil + payload, err := thrift.MarshalThriftData(ctx, thriftCodec, msg) + return payload, err } func DecodePayload(ctx context.Context, payload []byte, msg any) error { - err := thrift.FastUnmarshal(payload, msg.(thrift.FastCodec)) - return err + return thrift.UnmarshalThriftData(ctx, thriftCodec, "", payload, msg) } diff --git a/pkg/streamx/provider/ttstream/frame_handler.go b/pkg/streamx/provider/ttstream/frame_handler.go new file mode 100644 index 0000000000..033a64bde2 --- /dev/null +++ b/pkg/streamx/provider/ttstream/frame_handler.go @@ -0,0 +1,27 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "context" + + "github.com/cloudwego/kitex/pkg/streamx" +) + +type HeaderFrameHandler interface { + OnStream(ctx context.Context) (IntHeader, streamx.Header, error) +} diff --git a/pkg/streamx/provider/ttstream/frame_test.go b/pkg/streamx/provider/ttstream/frame_test.go index dc0d4d29e0..19283d18c5 100644 --- a/pkg/streamx/provider/ttstream/frame_test.go +++ b/pkg/streamx/provider/ttstream/frame_test.go @@ -1,28 +1,54 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( + "bytes" "context" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/remote" ) func TestFrameCodec(t *testing.T) { - rw := remote.NewReaderWriterBuffer(1024) - + var buf bytes.Buffer + writer := bufiox.NewDefaultWriter(&buf) + reader := bufiox.NewDefaultReader(&buf) wframe := newFrame(streamFrame{ - sid: 1, + sid: 0, method: "method", header: map[string]string{"key": "value"}, - }, headerFrameType, []byte("hello world")) - err := EncodeFrame(context.Background(), rw, wframe) - test.Assert(t, err == nil, err) + }, nil, headerFrameType, []byte("hello world")) - rframe, err := DecodeFrame(context.Background(), rw) + for i := 0; i < 10; i++ { + wframe.sid = int32(i) + err := EncodeFrame(context.Background(), writer, wframe) + test.Assert(t, err == nil, err) + } + err := writer.Flush() test.Assert(t, err == nil, err) - test.DeepEqual(t, string(wframe.payload), string(rframe.payload)) - test.DeepEqual(t, wframe, rframe) + + for i := 0; i < 10; i++ { + rframe, err := DecodeFrame(context.Background(), reader) + test.Assert(t, err == nil, err) + test.DeepEqual(t, string(wframe.payload), string(rframe.payload)) + test.DeepEqual(t, wframe.header, rframe.header) + } } func TestFrameWithoutPayloadCodec(t *testing.T) { diff --git a/pkg/streamx/provider/ttstream/ktx/ktx.go b/pkg/streamx/provider/ttstream/ktx/ktx.go new file mode 100644 index 0000000000..bfe80fad7b --- /dev/null +++ b/pkg/streamx/provider/ttstream/ktx/ktx.go @@ -0,0 +1,71 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ktx + +import ( + "context" + "sync" +) + +type ktxKey struct{} + +var _ context.Context = (*ktx)(nil) + +func WithCancel(ctx context.Context) (context.Context, context.CancelFunc) { + k := new(ktx) + ctx = context.WithValue(ctx, ktxKey{}, k) + k.Context = ctx + return ctx, k.cancelFunc() +} + +func RegisterCancelCallback(ctx context.Context, callback func()) bool { + v := ctx.Value(ktxKey{}) + if v == nil { + return false + } + k, ok := v.(*ktx) + if !ok { + return false + } + k.registerCancelCallback(callback) + return true +} + +type ktx struct { + context.Context + locker sync.Mutex + callbacks []func() +} + +func (k *ktx) cancelFunc() context.CancelFunc { + return func() { + k.locker.Lock() + callbacks := k.callbacks + k.locker.Unlock() + + // run all callbacks + for _, callback := range callbacks { + callback() + } + } +} + +func (k *ktx) registerCancelCallback(callback func()) { + k.locker.Lock() + k.callbacks = append(k.callbacks, callback) + k.locker.Unlock() +} diff --git a/pkg/streamx/provider/ttstream/ktx/ktx_test.go b/pkg/streamx/provider/ttstream/ktx/ktx_test.go new file mode 100644 index 0000000000..247f8f4d15 --- /dev/null +++ b/pkg/streamx/provider/ttstream/ktx/ktx_test.go @@ -0,0 +1,41 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ktx + +import ( + "context" + "testing" +) + +func TestKtx(t *testing.T) { + ctx := context.Background() + + // server start + ctx, cancelFunc := WithCancel(ctx) + + // client call + var clientCanceled int32 + RegisterCancelCallback(ctx, func() { + clientCanceled++ + }) + + // server recv exception + cancelFunc() + if clientCanceled != 1 { + t.Fatal() + } +} diff --git a/pkg/streamx/provider/ttstream/meta_frame_handler.go b/pkg/streamx/provider/ttstream/meta_frame_handler.go index 2c193cf702..35c3f5f67c 100644 --- a/pkg/streamx/provider/ttstream/meta_frame_handler.go +++ b/pkg/streamx/provider/ttstream/meta_frame_handler.go @@ -1,6 +1,26 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream -import "sync" +import ( + "sync" + + "github.com/cloudwego/kitex/pkg/streamx" +) type StreamMeta interface { Meta() map[string]string @@ -9,7 +29,7 @@ type StreamMeta interface { } type MetaFrameHandler interface { - OnMetaFrame(smeta StreamMeta, intHeader IntHeader, header Header, payload []byte) error + OnMetaFrame(smeta StreamMeta, intHeader IntHeader, header streamx.Header, payload []byte) error } var _ StreamMeta = (*streamMeta)(nil) diff --git a/pkg/streamx/provider/ttstream/metadata.go b/pkg/streamx/provider/ttstream/metadata.go index 5d9f7973d5..fef0b03789 100644 --- a/pkg/streamx/provider/ttstream/metadata.go +++ b/pkg/streamx/provider/ttstream/metadata.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( @@ -6,10 +22,10 @@ import ( "github.com/cloudwego/kitex/pkg/streamx" ) -var ErrInvalidStreamKind = errors.New("invalid stream kind") - -type Header map[string]string -type Trailer map[string]string +var ( + ErrInvalidStreamKind = errors.New("invalid stream kind") + ErrClosedStream = errors.New("stream is closed") +) // only for meta frame handler type IntHeader map[uint16]string @@ -17,14 +33,14 @@ type IntHeader map[uint16]string // ClientStreamMeta cannot send header directly, should send from ctx type ClientStreamMeta interface { streamx.ClientStream - Header() (Header, error) - Trailer() (Trailer, error) + Header() (streamx.Header, error) + Trailer() (streamx.Trailer, error) } // ServerStreamMeta cannot read header directly, should read from ctx type ServerStreamMeta interface { streamx.ServerStream - SetHeader(hd Header) error - SendHeader(hd Header) error - SetTrailer(hd Trailer) error + SetHeader(hd streamx.Header) error + SendHeader(hd streamx.Header) error + SetTrailer(hd streamx.Trailer) error } diff --git a/pkg/streamx/provider/ttstream/mock_test.go b/pkg/streamx/provider/ttstream/mock_test.go index d289696a8b..a2d99b58dc 100644 --- a/pkg/streamx/provider/ttstream/mock_test.go +++ b/pkg/streamx/provider/ttstream/mock_test.go @@ -1,8 +1,25 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( "encoding/json" "fmt" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" kutils "github.com/cloudwego/kitex/pkg/utils" ) diff --git a/pkg/streamx/provider/ttstream/scavenger.go b/pkg/streamx/provider/ttstream/scavenger.go deleted file mode 100644 index a809627448..0000000000 --- a/pkg/streamx/provider/ttstream/scavenger.go +++ /dev/null @@ -1,42 +0,0 @@ -package ttstream - -import ( - "sync" - "time" -) - -type Object interface { - Available() bool - Close() error -} - -func newScavenger() *scavenger { - s := new(scavenger) - go s.Cleaning() - return s -} - -type scavenger struct { - sync.RWMutex - objects []Object -} - -func (s *scavenger) Add(o Object) { - s.Lock() - s.objects = append(s.objects, o) - s.Unlock() -} - -func (s *scavenger) Cleaning() { - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - for range ticker.C { - s.RLock() - for _, o := range s.objects { - if !o.Available() { - _ = o.Close() - } - } - s.RUnlock() - } -} diff --git a/pkg/streamx/provider/ttstream/server_option.go b/pkg/streamx/provider/ttstream/server_option.go index 5c8e1a9455..2e2b619096 100644 --- a/pkg/streamx/provider/ttstream/server_option.go +++ b/pkg/streamx/provider/ttstream/server_option.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream type ServerProviderOption func(pc *serverProvider) diff --git a/pkg/streamx/provider/ttstream/server_provider.go b/pkg/streamx/provider/ttstream/server_provider.go index 57740e7e0b..5ecdfb26ca 100644 --- a/pkg/streamx/provider/ttstream/server_provider.go +++ b/pkg/streamx/provider/ttstream/server_provider.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( @@ -8,10 +24,12 @@ import ( "github.com/cloudwego/gopkg/protocol/ttheader" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/ktx" "github.com/cloudwego/netpoll" ) type serverTransCtxKey struct{} +type serverStreamCancelCtxKey struct{} func NewServerProvider(sinfo *serviceinfo.ServiceInfo, opts ...ServerProviderOption) (streamx.ServerProvider, error) { sp := new(serverProvider) @@ -38,20 +56,16 @@ func (s serverProvider) Available(ctx context.Context, conn net.Conn) bool { } func (s serverProvider) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { - trans := newTransport(serverTransport, s.sinfo, conn.(netpoll.Connection)) + nconn := conn.(netpoll.Connection) + trans := newTransport(serverTransport, s.sinfo, nconn) + _ = nconn.(onDisConnectSetter).SetOnDisconnect(func(ctx context.Context, connection netpoll.Connection) { + // server only close transport when peer connection closed + _ = trans.Close() + }) return context.WithValue(ctx, serverTransCtxKey{}, trans), nil } func (s serverProvider) OnInactive(ctx context.Context, conn net.Conn) (context.Context, error) { - trans, _ := ctx.Value(serverTransCtxKey{}).(*transport) - if trans == nil { - return ctx, nil - } - // server should Close transport - err := trans.Close() - if err != nil { - return nil, err - } return ctx, nil } @@ -60,22 +74,30 @@ func (s serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.Co if trans == nil { return nil, nil, nil } - st, err := trans.readStream() + st, err := trans.readStream(ctx) if err != nil { return nil, nil, err } ctx = metainfo.SetMetaInfoFromMap(ctx, st.header) ss := newServerStream(st) + + ctx, cancelFunc := ktx.WithCancel(ctx) + ctx = context.WithValue(ctx, serverStreamCancelCtxKey{}, cancelFunc) return ctx, ss, nil } func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStream) (context.Context, error) { sst := ss.(*serverStream) - if err := sst.sendTrailer(); err != nil { - return nil, err - } - if err := sst.close(); err != nil { - return nil, err + _ = sst.close() + + cancelFunc, _ := ctx.Value(serverStreamCancelCtxKey{}).(context.CancelFunc) + if cancelFunc != nil { + cancelFunc() } + return ctx, nil } + +type onDisConnectSetter interface { + SetOnDisconnect(onDisconnect netpoll.OnDisconnect) error +} diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index 458e79ee95..20f0915627 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -1,56 +1,81 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( "context" "errors" + "fmt" "sync/atomic" + "time" - "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/ttheader" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx" ) var ( - _ streamx.ClientStream = (*clientStream)(nil) - _ streamx.ServerStream = (*serverStream)(nil) - _ streamx.ClientStreamMetadata[Header, Trailer] = (*clientStream)(nil) - _ streamx.ServerStreamMetadata[Header, Trailer] = (*serverStream)(nil) - _ StreamMeta = (*stream)(nil) + _ streamx.ClientStream = (*clientStream)(nil) + _ streamx.ServerStream = (*serverStream)(nil) + _ streamx.ClientStreamMetadata = (*clientStream)(nil) + _ streamx.ServerStreamMetadata = (*serverStream)(nil) + _ StreamMeta = (*stream)(nil) ) -func newStream(trans *transport, mode streamx.StreamingMode, smeta streamFrame) (s *stream) { - s = new(stream) +func newStream(trans *transport, mode streamx.StreamingMode, smeta streamFrame) *stream { + s := new(stream) s.streamFrame = smeta s.trans = trans s.mode = mode - s.headerSig = make(chan struct{}) - s.trailerSig = make(chan struct{}) + s.wheader = make(streamx.Header) + s.wtrailer = make(streamx.Trailer) + s.headerSig = make(chan int32, 1) + s.trailerSig = make(chan int32, 1) s.StreamMeta = newStreamMeta() - trans.storeStreamIO(s) return s } type streamFrame struct { sid int32 method string - header Header // key:value, key is full name - trailer Trailer + header streamx.Header // key:value, key is full name + trailer streamx.Trailer } +const ( + streamSigNone int32 = 0 + streamSigActive int32 = 1 + streamSigInactive int32 = -1 +) + type stream struct { streamFrame trans *transport mode streamx.StreamingMode - wheader Header - wtrailer Trailer + wheader streamx.Header // wheader == nil means it already be sent + wtrailer streamx.Trailer // wtrailer == nil means it already be sent selfEOF int32 peerEOF int32 - headerSig chan struct{} - trailerSig chan struct{} + headerSig chan int32 + trailerSig chan int32 StreamMeta metaHandler MetaFrameHandler + recvTimeout time.Duration } func (s *stream) Mode() streamx.StreamingMode { @@ -68,32 +93,51 @@ func (s *stream) Method() string { return s.method } +func (s *stream) close() { + select { + case s.headerSig <- streamSigInactive: + default: + } + select { + case s.trailerSig <- streamSigInactive: + default: + } +} + func (s *stream) setMetaFrameHandler(h MetaFrameHandler) { s.metaHandler = h } -func (s *stream) readMetaFrame(intHeader IntHeader, header Header, payload []byte) (err error) { +func (s *stream) readMetaFrame(intHeader IntHeader, header streamx.Header, payload []byte) (err error) { if s.metaHandler == nil { return nil } return s.metaHandler.OnMetaFrame(s.StreamMeta, intHeader, header, payload) } -func (s *stream) readHeader(hd Header) (err error) { +func (s *stream) readHeader(hd streamx.Header) (err error) { s.header = hd select { - case <-s.headerSig: - return errors.New("already set header") + case s.headerSig <- streamSigActive: default: - close(s.headerSig) + return fmt.Errorf("stream[%d] already set header", s.sid) } klog.Debugf("stream[%s] read header: %v", s.method, hd) return nil } -func (s *stream) writeHeader(hd Header) (err error) { +// setHeader use the hd as the underlying header +func (s *stream) setHeader(hd streamx.Header) { + if hd != nil { + s.wheader = hd + } + return +} + +// writeHeader copy kvs into s.wheader +func (s *stream) writeHeader(hd streamx.Header) error { if s.wheader == nil { - s.wheader = make(Header) + return fmt.Errorf("stream header already sent") } for k, v := range hd { s.wheader[k] = v @@ -104,32 +148,38 @@ func (s *stream) writeHeader(hd Header) (err error) { func (s *stream) sendHeader() (err error) { wheader := s.wheader s.wheader = nil + if wheader == nil { + return fmt.Errorf("stream header already sent") + } err = s.trans.streamSendHeader(s.sid, s.method, wheader) return err } // readTrailer by client: unblock recv function and return EOF if no unread frame // readTrailer by server: unblock recv function and return EOF if no unread frame -func (s *stream) readTrailer(tl Trailer) (err error) { +func (s *stream) readTrailer(tl streamx.Trailer) (err error) { if !atomic.CompareAndSwapInt32(&s.peerEOF, 0, 1) { - return nil + return fmt.Errorf("stream read a unexcept trailer") } s.trailer = tl select { - case <-s.trailerSig: + case s.trailerSig <- streamSigActive: + default: return errors.New("already set trailer") + } + select { + case s.headerSig <- streamSigNone: + // if trailer arrived, we should return unblock stream.Header() default: - close(s.trailerSig) } - klog.Debugf("stream[%d] recv trailer: %v", s.sid, tl) return s.trans.streamCloseRecv(s) } -func (s *stream) writeTrailer(tl Trailer) (err error) { +func (s *stream) writeTrailer(tl streamx.Trailer) (err error) { if s.wtrailer == nil { - s.wtrailer = make(Trailer) + return fmt.Errorf("stream trailer already sent") } for k, v := range tl { s.wtrailer[k] = v @@ -141,25 +191,43 @@ func (s *stream) sendTrailer() (err error) { if !atomic.CompareAndSwapInt32(&s.selfEOF, 0, 1) { return nil } - klog.Debugf("stream[%d] send trialer", s.sid) - return s.trans.streamSendTrailer(s.sid, s.method, s.wtrailer) + wtrailer := s.wtrailer + s.wtrailer = nil + if wtrailer == nil { + return fmt.Errorf("stream trailer already sent") + } + klog.Debugf("transport[%d]-stream[%d] send trialer", s.trans.kind, s.sid) + return s.trans.streamCloseSend(s.sid, s.method, wtrailer) } -func (s *stream) SendMsg(ctx context.Context, res any) error { - payload, err := EncodePayload(ctx, res) - if err != nil { - return err +func (s *stream) finished() bool { + return atomic.LoadInt32(&s.peerEOF) == 1 && + atomic.LoadInt32(&s.selfEOF) == 1 +} + +func (s *stream) setRecvTimeout(timeout time.Duration) { + if timeout <= 0 { + return } - return s.trans.streamSend(s.sid, s.method, s.wheader, payload) + s.recvTimeout = timeout +} + +func (s *stream) SendMsg(ctx context.Context, res any) (err error) { + err = s.trans.streamSend(ctx, s.sid, s.method, s.wheader, res) + return err } func (s *stream) RecvMsg(ctx context.Context, req any) error { - payload, err := s.trans.streamRecv(s.sid) - if err != nil { - return err + if s.recvTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, s.recvTimeout) + defer cancel() } - err = DecodePayload(ctx, payload, req.(thrift.FastCodec)) - return err + return s.trans.streamRecv(ctx, s.sid, req) +} + +func (s *stream) cancel() { + _ = s.trans.streamCancel(s) } func newClientStream(s *stream) *clientStream { @@ -171,12 +239,12 @@ type clientStream struct { *stream } -func (s *clientStream) CloseSend(ctx context.Context) error { - return s.sendTrailer() +func (s *clientStream) RecvMsg(ctx context.Context, req any) error { + return s.stream.RecvMsg(ctx, req) } -func (s *clientStream) close() error { - return s.trans.streamClose(s.stream) +func (s *clientStream) CloseSend(ctx context.Context) error { + return s.sendTrailer() } func newServerStream(s *stream) streamx.ServerStream { @@ -188,8 +256,8 @@ type serverStream struct { *stream } -func (s *serverStream) close() error { - return s.trans.streamClose(s.stream) +func (s *serverStream) RecvMsg(ctx context.Context, req any) error { + return s.stream.RecvMsg(ctx, req) } // SendMsg should send left header first @@ -201,3 +269,16 @@ func (s *serverStream) SendMsg(ctx context.Context, res any) error { } return s.stream.SendMsg(ctx, res) } + +// close will be called after server handler returned +// after close stream cannot be access again +func (s *serverStream) close() error { + // write loop should help to delete stream + err := s.sendTrailer() + if err != nil { + return err + } + err = s.trans.streamDelete(s.sid) + s.stream.close() + return err +} diff --git a/pkg/streamx/provider/ttstream/stream_header_trailer.go b/pkg/streamx/provider/ttstream/stream_header_trailer.go index 5a91294dd9..beedb24f41 100644 --- a/pkg/streamx/provider/ttstream/stream_header_trailer.go +++ b/pkg/streamx/provider/ttstream/stream_header_trailer.go @@ -1,30 +1,67 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream +import ( + "errors" + + "github.com/cloudwego/kitex/pkg/streamx" +) + var _ ClientStreamMeta = (*clientStream)(nil) var _ ServerStreamMeta = (*serverStream)(nil) -func (s *clientStream) Header() (Header, error) { - <-s.headerSig - return s.header, nil +func (s *clientStream) Header() (streamx.Header, error) { + sig := <-s.headerSig + switch sig { + case streamSigActive: + return s.header, nil + case streamSigNone: + return make(streamx.Header), nil + case streamSigInactive: + return nil, ErrClosedStream + } + return nil, errors.New("invalid stream signal") } -func (s *clientStream) Trailer() (Trailer, error) { - <-s.trailerSig - return s.trailer, nil +func (s *clientStream) Trailer() (streamx.Trailer, error) { + sig := <-s.trailerSig + switch sig { + case streamSigActive: + return s.trailer, nil + case streamSigNone: + return make(streamx.Trailer), nil + case streamSigInactive: + return nil, ErrClosedStream + } + return nil, errors.New("invalid stream signal") } -func (s *serverStream) SetHeader(hd Header) error { +func (s *serverStream) SetHeader(hd streamx.Header) error { return s.writeHeader(hd) } -func (s *serverStream) SendHeader(hd Header) error { - err := s.writeHeader(hd) - if err != nil { +func (s *serverStream) SendHeader(hd streamx.Header) error { + if err := s.writeHeader(hd); err != nil { return err } return s.stream.sendHeader() } -func (s *serverStream) SetTrailer(tl Trailer) error { +func (s *serverStream) SetTrailer(tl streamx.Trailer) error { return s.writeTrailer(tl) } diff --git a/pkg/streamx/provider/ttstream/stream_io.go b/pkg/streamx/provider/ttstream/stream_io.go index 4680e09850..a03cc6c0ea 100644 --- a/pkg/streamx/provider/ttstream/stream_io.go +++ b/pkg/streamx/provider/ttstream/stream_io.go @@ -1,48 +1,91 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( + "context" + "errors" "io" - "sync" + "sync/atomic" + + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" ) type streamIO struct { - stream *stream - cond *sync.Cond - frames []Frame // TODO: using link list - end bool + ctx context.Context + trigger chan struct{} + stream *stream + // eofFlag == 2 when both parties send trailers + eofFlag int32 + // eofCallback will be called when eofFlag == 2 + // eofCallback will not be called if stream is not be ended in a normal way + eofCallback func() + fpipe *container.Pipe[*Frame] + fcache [1]*Frame +} + +func newStreamIO(ctx context.Context, s *stream) *streamIO { + sio := new(streamIO) + sio.ctx = ctx + sio.trigger = make(chan struct{}) + sio.stream = s + sio.fpipe = container.NewPipe[*Frame]() + return sio } -func newStreamIO(s *stream) *streamIO { - var lock sync.Mutex - var cond = sync.NewCond(&lock) - return &streamIO{stream: s, cond: cond} +func (s *streamIO) setEOFCallback(f func()) { + s.eofCallback = f } -func (s *streamIO) input(f Frame) { - s.cond.L.Lock() - s.frames = append(s.frames, f) - s.cond.L.Unlock() - s.cond.Signal() +func (s *streamIO) input(ctx context.Context, f *Frame) { + err := s.fpipe.Write(ctx, f) + if err != nil { + klog.Errorf("fpipe write failed: %v", err) + } } -func (s *streamIO) output() (f Frame, err error) { - s.cond.L.Lock() - for len(s.frames) == 0 && !s.end { - s.cond.Wait() +func (s *streamIO) output(ctx context.Context) (f *Frame, err error) { + n, err := s.fpipe.Read(ctx, s.fcache[:]) + if err != nil { + if errors.Is(err, container.ErrPipeEOF) { + return nil, io.EOF + } + return nil, err + } + if n == 0 { + return nil, io.EOF } - // have incoming frames or eof - if len(s.frames) == 0 && s.end { - return f, io.EOF + return s.fcache[0], nil +} + +func (s *streamIO) closeRecv() { + s.fpipe.Close() + if atomic.AddInt32(&s.eofFlag, 1) == 2 && s.eofCallback != nil { + s.eofCallback() + } +} + +func (s *streamIO) closeSend() { + if atomic.AddInt32(&s.eofFlag, 1) == 2 && s.eofCallback != nil { + s.eofCallback() } - f = s.frames[0] - s.frames = s.frames[1:] - s.cond.L.Unlock() - return f, nil } -func (s *streamIO) eof() { - s.cond.L.Lock() - s.end = true - s.cond.L.Unlock() - s.cond.Signal() +func (s *streamIO) cancel() { + s.fpipe.Cancel() } diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index f35a1c08f2..eb03cd458f 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -1,7 +1,24 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( "context" + "encoding/binary" "errors" "fmt" "io" @@ -9,67 +26,104 @@ import ( "sync/atomic" "time" - "github.com/cloudwego/gopkg/bufiox" + "github.com/bytedance/gopkg/lang/mcache" + "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" "github.com/cloudwego/netpoll" ) const ( clientTransport int32 = 1 serverTransport int32 = 2 + + streamCacheSize = 32 + frameChanSize = 32 ) -var _ Object = (*transport)(nil) +func isIgnoreError(err error) bool { + return errors.Is(err, netpoll.ErrEOF) || errors.Is(err, io.EOF) || errors.Is(err, netpoll.ErrConnClosed) +} type transport struct { - kind int32 - sinfo *serviceinfo.ServiceInfo - conn netpoll.Connection - reader bufiox.Reader - writer bufiox.Writer - streams sync.Map // key=streamID val=streamIO - sch chan *stream // in-coming stream channel - wch chan Frame // out-coming frame channel - stop chan struct{} - - // for scavenger check - lastActive atomic.Value // time.Time + kind int32 + sinfo *serviceinfo.ServiceInfo + conn netpoll.Connection + streams sync.Map // key=streamID val=streamIO + scache []*stream // size is streamCacheSize + spipe *container.Pipe[*stream] // in-coming stream channel + wchannel chan *Frame + closed chan struct{} + closedFlag int32 + streamingFlag int32 // flag == 0 means there is no active stream on transport } func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Connection) *transport { - _ = conn.SetDeadline(time.Now().Add(time.Hour)) - reader := bufiox.NewDefaultReader(conn) - writer := bufiox.NewDefaultWriter(conn) + // stream max idle session is 10 minutes. + // TODO: let it configurable + _ = conn.SetReadTimeout(time.Minute * 10) t := &transport{ - kind: kind, - sinfo: sinfo, - conn: conn, - reader: reader, - writer: writer, - streams: sync.Map{}, - sch: make(chan *stream, 8), - wch: make(chan Frame, 8), - stop: make(chan struct{}), + kind: kind, + sinfo: sinfo, + conn: conn, + streams: sync.Map{}, + spipe: container.NewPipe[*stream](), + scache: make([]*stream, 0, streamCacheSize), + wchannel: make(chan *Frame, frameChanSize), + closed: make(chan struct{}), } go func() { err := t.loopRead() - if err != nil && errors.Is(err, io.EOF) { - klog.Warnf("trans loop read err: %v", err) + if err != nil { + if !isIgnoreError(err) { + klog.Warnf("transport[%d] loop read err: %v", t.kind, err) + } + // if connection is closed by peer, loop read should return ErrConnClosed error, + // so we should close transport here + _ = t.Close() } }() go func() { err := t.loopWrite() - if err != nil && errors.Is(err, io.EOF) { - klog.Warnf("trans loop write err: %v", err) - return + if err != nil { + if !isIgnoreError(err) { + klog.Warnf("transport[%d] loop write err: %v", t.kind, err) + } + _ = t.Close() } }() return t } -func (t *transport) storeStreamIO(s *stream) { - t.streams.Store(s.sid, newStreamIO(s)) +// Close will close transport and destroy all resource and goroutines +// server close transport when connection is disconnected +// client close transport when transPool discard the transport +func (t *transport) Close() (err error) { + if !atomic.CompareAndSwapInt32(&t.closedFlag, 0, 1) { + return nil + } + close(t.closed) + klog.Debugf("transport[%s] is closing", t.conn.LocalAddr()) + t.spipe.Close() + t.streams.Range(func(key, value any) bool { + sio := value.(*streamIO) + sio.stream.close() + _ = t.streamDelete(sio.stream.sid) + return true + }) + return err +} + +func (t *transport) IsActive() bool { + return atomic.LoadInt32(&t.closedFlag) == 0 && t.conn.IsActive() +} + +func (t *transport) storeStreamIO(ctx context.Context, s *stream) *streamIO { + sio := newStreamIO(ctx, s) + t.streams.Store(s.sid, sio) + return sio } func (t *transport) loadStreamIO(sid int32) (sio *streamIO, ok bool) { @@ -82,21 +136,34 @@ func (t *transport) loadStreamIO(sid int32) (sio *streamIO, ok bool) { } func (t *transport) loopRead() error { + addr := t.conn.RemoteAddr().String() + if t.kind == clientTransport { + addr = t.conn.LocalAddr().String() + } for { - now := time.Now() - t.lastActive.Store(now) - // decode frame - fr, err := DecodeFrame(context.Background(), t.reader) + sizeBuf, err := t.conn.Reader().Peek(4) if err != nil { return err } + size := binary.BigEndian.Uint32(sizeBuf) + slice, err := t.conn.Reader().Slice(int(size + 4)) + if err != nil { + return err + } + reader := newReaderBuffer(slice) + fr, err := DecodeFrame(context.Background(), reader) + if err != nil { + return err + } + klog.Debugf("transport[%d-%s] DecodeFrame: fr=%v", t.kind, addr, fr) switch fr.typ { case metaFrameType: sio, ok := t.loadStreamIO(fr.sid) if !ok { - return fmt.Errorf("transport[%d] read a unknown stream meta: sid=%d", t.kind, fr.sid) + klog.Errorf("transport[%d-%s] read a unknown stream meta: sid=%d", t.kind, addr, fr.sid) + continue } err = sio.stream.readMetaFrame(fr.meta, fr.header, fr.payload) if err != nil { @@ -108,13 +175,15 @@ func (t *transport) loopRead() error { // Header Frame: server recv a new stream smode := t.sinfo.MethodInfo(fr.method).StreamingMode() s := newStream(t, smode, fr.streamFrame) - klog.Debugf("transport[%d] read a new stream: sid=%d", t.kind, s.sid) - t.sch <- s + t.storeStreamIO(context.Background(), s) + t.spipe.Write(context.Background(), s) case clientTransport: // Header Frame: client recv header sio, ok := t.loadStreamIO(fr.sid) if !ok { - return fmt.Errorf("transport[%d] read a unknown stream header: sid=%d", t.kind, fr.sid) + klog.Errorf("transport[%d-%s] read a unknown stream header: sid=%d header=%v", + t.kind, addr, fr.sid, fr.header) + continue } err = sio.stream.readHeader(fr.header) if err != nil { @@ -125,14 +194,20 @@ func (t *transport) loopRead() error { // Data Frame: decode and distribute data sio, ok := t.loadStreamIO(fr.sid) if !ok { - return fmt.Errorf("transport[%d] read a unknown stream data: sid=%d", t.kind, fr.sid) + klog.Errorf("transport[%d-%s] read a unknown stream data: sid=%d", t.kind, addr, fr.sid) + continue } - sio.input(fr) + sio.input(context.Background(), fr) case trailerFrameType: // Trailer Frame: recv trailer, Close read direction sio, ok := t.loadStreamIO(fr.sid) if !ok { - return fmt.Errorf("transport[%d] read a unknown stream trailer: sid=%d", t.kind, fr.sid) + // client recv an unknown trailer is in exception, + // because the client stream may already be GCed, + // but the connection is still active so peer server can send a trailer + klog.Errorf("transport[%d-%s] read a unknown stream trailer: sid=%d trailer=%v", + t.kind, addr, fr.sid, fr.trailer) + continue } if err = sio.stream.readTrailer(fr.trailer); err != nil { return err @@ -141,90 +216,110 @@ func (t *transport) loopRead() error { } } -func (t *transport) writeFrame(frame Frame) error { - err := EncodeFrame(context.Background(), t.writer, frame) - return err -} - -func (t *transport) loopWrite() error { +func (t *transport) loopWrite() (err error) { + defer func() { + // loop write should help to close connection + _ = t.conn.Close() + }() + writer := newWriterBuffer(t.conn.Writer()) + delay := 0 + // Important note: + // loopWrite may cannot find stream by sid since it may send trailer and delete sid from streams for { - now := time.Now() - t.lastActive.Store(now) - select { - case <-t.stop: - // re-check wch queue + case <-t.closed: + return nil + case fr, ok := <-t.wchannel: + if !ok { + // closed + return nil + } select { - case frame := <-t.wch: - if err := t.writeFrame(frame); err != nil { - return err - } - default: + case <-t.closed: + // double check closed return nil + default: } - case frame := <-t.wch: - if err := t.writeFrame(frame); err != nil { + + if err = EncodeFrame(context.Background(), writer, fr); err != nil { return err } + if delay >= 8 || len(t.wchannel) == 0 { + delay = 0 + if err = t.conn.Writer().Flush(); err != nil { + return err + } + } else { + delay++ + } } } } -func (t *transport) Available() bool { - v := t.lastActive.Load() - if v == nil { - return true - } - lastActive := v.(time.Time) - // let unavailable time configurable - return time.Now().Sub(lastActive) < time.Minute*10 -} - -func (t *transport) Close() (err error) { - select { - case <-t.stop: - default: - klog.Warnf("transport[%s] is closing", t.conn.LocalAddr()) - close(t.stop) - t.conn.Close() +// writeFrame is concurrent safe +func (t *transport) writeFrame(sframe streamFrame, meta IntHeader, ftype int32, data any) (err error) { + var payload []byte + if data != nil { + // payload should be written nocopy + payload, err = EncodePayload(context.Background(), data) + if err != nil { + return err + } } + frame := newFrame(sframe, meta, ftype, payload) + t.wchannel <- frame return nil } -func (t *transport) streamSend(sid int32, method string, wheader Header, payload []byte) (err error) { +func (t *transport) streamSend(ctx context.Context, sid int32, method string, wheader streamx.Header, res any) (err error) { if len(wheader) > 0 { err = t.streamSendHeader(sid, method, wheader) if err != nil { return err } } - f := newFrame(streamFrame{sid: sid, method: method}, dataFrameType, payload) - t.wch <- f - return nil + return t.writeFrame( + streamFrame{sid: sid, method: method}, + nil, dataFrameType, res, + ) } -func (t *transport) streamSendHeader(sid int32, method string, header Header) (err error) { - f := newFrame(streamFrame{sid: sid, method: method, header: header}, headerFrameType, []byte{}) - t.wch <- f - return nil +func (t *transport) streamSendHeader(sid int32, method string, header streamx.Header) (err error) { + return t.writeFrame( + streamFrame{sid: sid, method: method, header: header}, + nil, headerFrameType, nil) } -func (t *transport) streamSendTrailer(sid int32, method string, trailer Trailer) (err error) { - f := newFrame(streamFrame{sid: sid, method: method, trailer: trailer}, trailerFrameType, []byte{}) - t.wch <- f +func (t *transport) streamCloseSend(sid int32, method string, trailer streamx.Trailer) (err error) { + err = t.writeFrame( + streamFrame{sid: sid, method: method, trailer: trailer}, + nil, trailerFrameType, nil, + ) + if err != nil { + return err + } + sio, ok := t.loadStreamIO(sid) + if !ok { + return nil + } + sio.closeSend() return nil } -func (t *transport) streamRecv(sid int32) (payload []byte, err error) { +func (t *transport) streamRecv(ctx context.Context, sid int32, data any) (err error) { sio, ok := t.loadStreamIO(sid) if !ok { - return nil, io.EOF + return io.EOF } - f, err := sio.output() + f, err := sio.output(ctx) if err != nil { - return nil, err + return err } - return f.payload, nil + err = DecodePayload(context.Background(), f.payload, data.(thrift.FastCodec)) + // payload will not be access after decode + mcache.Free(f.payload) + recycleFrame(f) + return nil } func (t *transport) streamCloseRecv(s *stream) (err error) { @@ -232,52 +327,83 @@ func (t *transport) streamCloseRecv(s *stream) (err error) { if !ok { return fmt.Errorf("stream not found in stream map: sid=%d", s.sid) } - sio.eof() + sio.closeRecv() + return nil +} + +func (t *transport) streamCancel(s *stream) (err error) { + sio, ok := t.loadStreamIO(s.sid) + if !ok { + return fmt.Errorf("stream not found in stream map: sid=%d", s.sid) + } + sio.cancel() return nil } -func (t *transport) streamClose(s *stream) (err error) { +func (t *transport) streamDelete(sid int32) (err error) { // remove stream from transport - t.streams.Delete(s.sid) + _, ok := t.streams.LoadAndDelete(sid) + if !ok { + return nil + } + atomic.AddInt32(&t.streamingFlag, -1) return nil } +func (t *transport) IsStreaming() bool { + return atomic.LoadInt32(&t.streamingFlag) > 0 +} + var clientStreamID int32 -// newStream create new stream on current connection +// newStreamIO create new stream on current connection // it's typically used by client side -// newStream is concurrency safe -func (t *transport) newStream( - ctx context.Context, method string, header map[string]string) (*stream, error) { +// newStreamIO is concurrency safe +func (t *transport) newStreamIO( + ctx context.Context, method string, intHeader IntHeader, strHeader streamx.Header) (*streamIO, error) { if t.kind != clientTransport { return nil, fmt.Errorf("transport already be used as other kind") } + sid := atomic.AddInt32(&clientStreamID, 1) smode := t.sinfo.MethodInfo(method).StreamingMode() - smeta := streamFrame{ - sid: sid, - method: method, - header: header, + // create stream + err := t.writeFrame( + streamFrame{sid: sid, method: method, header: strHeader}, + intHeader, headerFrameType, nil, + ) + if err != nil { + return nil, err } - f := newFrame(smeta, headerFrameType, []byte{}) - s := newStream(t, smode, smeta) - t.wch <- f // create stream - return s, nil + s := newStream(t, smode, streamFrame{sid: sid, method: method}) + sio := t.storeStreamIO(ctx, s) + atomic.AddInt32(&t.streamingFlag, 1) + return sio, nil } // readStream wait for a new incoming stream on current connection // it's typically used by server side -func (t *transport) readStream() (*stream, error) { +func (t *transport) readStream(ctx context.Context) (*stream, error) { if t.kind != serverTransport { return nil, fmt.Errorf("transport already be used as other kind") } - select { - case <-t.stop: - return nil, io.EOF - case s := <-t.sch: - if s == nil { +READ: + if len(t.scache) > 0 { + s := t.scache[len(t.scache)-1] + t.scache = t.scache[:len(t.scache)-1] + atomic.AddInt32(&t.streamingFlag, 1) + return s, nil + } + n, err := t.spipe.Read(ctx, t.scache[0:streamCacheSize]) + if err != nil { + if errors.Is(err, container.ErrPipeEOF) { return nil, io.EOF } - return s, nil + return nil, err + } + if n == 0 { + panic("Assert: N == 0 !") } + t.scache = t.scache[:n] + goto READ } diff --git a/pkg/streamx/provider/ttstream/transport_buffer.go b/pkg/streamx/provider/ttstream/transport_buffer.go new file mode 100644 index 0000000000..a51a47c64d --- /dev/null +++ b/pkg/streamx/provider/ttstream/transport_buffer.go @@ -0,0 +1,132 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "sync" + + "github.com/cloudwego/gopkg/bufiox" + gopkgthrift "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/netpoll" +) + +var _ bufiox.Reader = (*readerBuffer)(nil) +var _ bufiox.Writer = (*writerBuffer)(nil) +var _ gopkgthrift.NocopyWriter = (*writerBuffer)(nil) + +var readerBufferPool sync.Pool +var writerBufferPool sync.Pool + +func newReaderBuffer(reader netpoll.Reader) (rb *readerBuffer) { + if v := readerBufferPool.Get(); v != nil { + rb = v.(*readerBuffer) + } else { + rb = new(readerBuffer) + } + rb.reader = reader + rb.readSize = 0 + return rb +} + +type readerBuffer struct { + reader netpoll.Reader + readSize int +} + +func (c *readerBuffer) Next(n int) (p []byte, err error) { + p, err = c.reader.Next(n) + c.readSize += len(p) + return p, err +} + +func (c *readerBuffer) ReadBinary(bs []byte) (n int, err error) { + n = len(bs) + buf, err := c.reader.Next(n) + if err != nil { + return 0, err + } + copy(bs, buf) + c.readSize += n + return n, nil +} + +func (c *readerBuffer) Peek(n int) (buf []byte, err error) { + return c.reader.Peek(n) +} + +func (c *readerBuffer) Skip(n int) (err error) { + err = c.reader.Skip(n) + if err != nil { + return err + } + c.readSize += n + return nil +} + +func (c *readerBuffer) ReadLen() (n int) { + return c.readSize +} + +func (c *readerBuffer) Release(e error) (err error) { + c.readSize = 0 + return c.reader.Release() +} + +func newWriterBuffer(writer netpoll.Writer) (wb *writerBuffer) { + if v := writerBufferPool.Get(); v != nil { + wb = v.(*writerBuffer) + } else { + wb = new(writerBuffer) + } + wb.writer = writer + wb.writeSize = 0 + return wb +} + +type writerBuffer struct { + writer netpoll.Writer + writeSize int +} + +func (c *writerBuffer) Malloc(n int) (buf []byte, err error) { + c.writeSize += n + return c.writer.Malloc(n) +} + +func (c *writerBuffer) WriteBinary(bs []byte) (n int, err error) { + n, err = c.writer.WriteBinary(bs) + c.writeSize += n + return n, err +} + +func (c *writerBuffer) WriteDirect(b []byte, remainCap int) (err error) { + err = c.writer.WriteDirect(b, remainCap) + c.writeSize += len(b) + return err +} + +func (c *writerBuffer) WrittenLen() (length int) { + return c.writeSize +} + +func (c *writerBuffer) Flush() (err error) { + err = c.writer.Flush() + c.writer = nil + c.writeSize = 0 + writerBufferPool.Put(c) + return err +} diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go index f03b4b3741..ec85dd393c 100644 --- a/pkg/streamx/provider/ttstream/transport_test.go +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream import ( @@ -51,7 +67,7 @@ func TestTransport(t *testing.T) { go func() { for { - s, err := trans.readStream() + s, err := trans.readStream(ctx) t.Logf("OnRead read stream: %v, %v", s, err) if err != nil { if err == io.EOF { @@ -63,7 +79,7 @@ func TestTransport(t *testing.T) { go func(st streamx.ServerStream) { defer func() { // set trailer - err := st.(ServerStreamMeta).SetTrailer(Trailer{"key": "val"}) + err := st.(ServerStreamMeta).SetTrailer(streamx.Trailer{"key": "val"}) test.Assert(t, err == nil, err) // send trailer @@ -73,7 +89,7 @@ func TestTransport(t *testing.T) { }() // send header - err := st.(ServerStreamMeta).SendHeader(Header{"key": "val"}) + err := st.(ServerStreamMeta).SendHeader(streamx.Header{"key": "val"}) test.Assert(t, err == nil, err) // send data @@ -81,7 +97,7 @@ func TestTransport(t *testing.T) { req := new(TestRequest) err := st.RecvMsg(ctx, req) if errors.Is(err, io.EOF) { - t.Logf("server stream eof") + t.Logf("server stream closeRecv") return } test.Assert(t, err == nil, err) @@ -128,10 +144,10 @@ func TestTransport(t *testing.T) { defer wg.Done() // send header - s, err := trans.newStream(ctx, method, map[string]string{}) + sio, err := trans.newStreamIO(ctx, method, IntHeader{}, map[string]string{}) test.Assert(t, err == nil, err) - cs := newClientStream(s) + cs := newClientStream(sio.stream) t.Logf("client stream[%d] created", sid) // recv header diff --git a/pkg/streamx/provider/ttstream/ttstream_client_test.go b/pkg/streamx/provider/ttstream/ttstream_client_test.go index 2aedc8f488..4aa3bbe9bb 100644 --- a/pkg/streamx/provider/ttstream/ttstream_client_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_client_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream_test import ( @@ -16,6 +32,7 @@ import ( "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/streamxclient" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" @@ -25,11 +42,23 @@ import ( "github.com/cloudwego/netpoll" ) +func init() { + klog.SetLevel(klog.LevelDebug) +} + +func testHeaderAndTrailer(t *testing.T, stream streamx.ClientStreamMetadata) { + hd, err := stream.Header() + test.Assert(t, err == nil, err) + test.Assert(t, hd[headerKey] == headerVal, hd) + tl, err := stream.Trailer() + test.Assert(t, err == nil, err) + test.Assert(t, tl[trailerKey] == trailerVal, tl) +} + func TestTTHeaderStreaming(t *testing.T) { go func() { log.Println(http.ListenAndServe("localhost:6060", nil)) }() - var addr = test.GetLocalAddress() ln, err := netpoll.CreateListener("tcp", addr) test.Assert(t, err == nil, err) @@ -43,9 +72,8 @@ func TestTTHeaderStreaming(t *testing.T) { time.Sleep(time.Millisecond * 100) } } - methodCount := map[string]int{} - serverRecvCount := map[string]int{} - serverSendCount := map[string]int{} + var serverRecvCount int32 + var serverSendCount int32 svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) // register pingpong service err = svr.RegisterService(pingpongServiceInfo, new(pingpongService)) @@ -61,7 +89,7 @@ func TestTTHeaderStreaming(t *testing.T) { return func(ctx context.Context, stream streamx.Stream, res any) (err error) { err = next(ctx, stream, res) if err == nil { - serverRecvCount[stream.Method()]++ + atomic.AddInt32(&serverRecvCount, 1) } else { log.Printf("server recv middleware err=%v", err) } @@ -72,7 +100,7 @@ func TestTTHeaderStreaming(t *testing.T) { return func(ctx context.Context, stream streamx.Stream, req any) (err error) { err = next(ctx, stream, req) if err == nil { - serverSendCount[stream.Method()]++ + atomic.AddInt32(&serverSendCount, 1) } else { log.Printf("server send middleware err=%v", err) } @@ -116,7 +144,6 @@ func TestTTHeaderStreaming(t *testing.T) { test.Assert(t, resArgs.Res() == nil) } test.Assert(t, err == nil, err) - methodCount[streamArgs.Stream().Method()]++ log.Printf("Server handler end") log.Printf("Server middleware after next: reqArgs=%v resArgs=%v streamArgs=%v", @@ -143,8 +170,14 @@ func TestTTHeaderStreaming(t *testing.T) { client.WithPayloadCodec(thrift.NewThriftCodecWithConfig(thrift.FastRead|thrift.FastWrite|thrift.EnableSkipDecoder)), ) test.Assert(t, err == nil, err) + // create streaming client + cp, _ := ttstream.NewClientProvider( + streamingServiceInfo, + ttstream.WithClientLongConnPool(ttstream.DefaultLongConnConfig), + ) streamClient, err := NewStreamingClient( "kitex.service.streaming", + streamxclient.WithProvider(cp), streamxclient.WithHostPorts(addr), streamxclient.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { return func(ctx context.Context, stream streamx.Stream, res any) (err error) { @@ -212,10 +245,12 @@ func TestTTHeaderStreaming(t *testing.T) { test.Assert(t, err == nil, err) test.Assert(t, req.Type == res.Type, res.Type) test.Assert(t, req.Message == res.Message, res.Message) - test.Assert(t, serverRecvCount["Unary"] == 1, serverRecvCount) - test.Assert(t, serverSendCount["Unary"] == 1, serverSendCount) + test.Assert(t, serverRecvCount == 1, serverRecvCount) + test.Assert(t, serverSendCount == 1, serverSendCount) atomic.AddInt32(&serverStreamCount, -1) waitServerStreamDone() + serverRecvCount = 0 + serverSendCount = 0 // client stream round := 5 @@ -235,9 +270,12 @@ func TestTTHeaderStreaming(t *testing.T) { t.Logf("Client ClientStream CloseAndRecv: %v", res) atomic.AddInt32(&serverStreamCount, -1) waitServerStreamDone() - test.Assert(t, serverRecvCount["ClientStream"] == round, serverRecvCount) - test.Assert(t, serverSendCount["ClientStream"] == 1, serverSendCount) + test.Assert(t, serverRecvCount == int32(round), serverRecvCount) + test.Assert(t, serverSendCount == 1, serverSendCount) + testHeaderAndTrailer(t, cs) cs = nil + serverRecvCount = 0 + serverSendCount = 0 runtime.GC() // server stream @@ -246,12 +284,6 @@ func TestTTHeaderStreaming(t *testing.T) { req.Message = "ServerStream" ss, err := streamClient.ServerStream(ctx, req) test.Assert(t, err == nil, err) - // server stream recv header - hd, err := ss.Header() - test.Assert(t, err == nil, err) - t.Logf("Client ServerStream recv header: %v", hd) - test.DeepEqual(t, hd["key1"], "val1") - test.DeepEqual(t, hd["key2"], "val2") received := 0 for { res, err := ss.Recv(ctx) @@ -264,59 +296,275 @@ func TestTTHeaderStreaming(t *testing.T) { } err = ss.CloseSend(ctx) test.Assert(t, err == nil, err) - // server stream recv trailer - tl, err := ss.Trailer() - test.Assert(t, err == nil, err) - t.Logf("Client ServerStream recv trailer: %v", tl) - test.DeepEqual(t, tl["key1"], "val1") - test.DeepEqual(t, tl["key2"], "val2") atomic.AddInt32(&serverStreamCount, -1) waitServerStreamDone() - test.Assert(t, serverRecvCount["ServerStream"] == 1, serverRecvCount) - test.Assert(t, serverSendCount["ServerStream"] == received, serverSendCount) + test.Assert(t, serverRecvCount == 1, serverRecvCount) + test.Assert(t, serverSendCount == int32(received), serverSendCount, received) + testHeaderAndTrailer(t, ss) ss = nil + serverRecvCount = 0 + serverSendCount = 0 runtime.GC() // bidi stream - round = 5 t.Logf("=== BidiStream ===") - bs, err := streamClient.BidiStream(ctx) - test.Assert(t, err == nil, err) + concurrent := 1 + round = 5 + for c := 0; c < concurrent; c++ { + atomic.AddInt32(&serverStreamCount, -1) + go func() { + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + msg := "BidiStream" + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i := 0; i < round; i++ { + req := new(Request) + req.Message = msg + err := bs.Send(ctx, req) + test.Assert(t, err == nil, err) + } + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + }() + go func() { + defer wg.Done() + i := 0 + for { + res, err := bs.Recv(ctx) + t.Log(res, err) + if errors.Is(err, io.EOF) { + break + } + i++ + test.Assert(t, err == nil, err) + test.Assert(t, msg == res.Message, res.Message) + } + test.Assert(t, i == round, i) + }() + testHeaderAndTrailer(t, bs) + }() + } + waitServerStreamDone() + test.Assert(t, serverRecvCount == int32(concurrent*round), serverRecvCount) + test.Assert(t, serverSendCount == int32(concurrent*round), serverSendCount) + serverRecvCount = 0 + serverSendCount = 0 + runtime.GC() + + streamClient = nil +} + +func TestTTHeaderStreamingLongConn(t *testing.T) { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + + var addr = test.GetLocalAddress() + ln, _ := netpoll.CreateListener("tcp", addr) + defer ln.Close() + + // create server + svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) + // register streamingService as ttstreaam provider + sp, _ := ttstream.NewServerProvider(streamingServiceInfo) + _ = svr.RegisterService( + streamingServiceInfo, + new(streamingService), + streamxserver.WithProvider(sp), + ) + go func() { + _ = svr.Run() + }() + defer svr.Stop() + test.WaitServerStart(addr) + + numGoroutine := runtime.NumGoroutine() + cp, _ := ttstream.NewClientProvider( + streamingServiceInfo, + ttstream.WithClientLongConnPool( + ttstream.LongConnConfig{MaxIdleTimeout: time.Second}, + ), + ) + streamClient, _ := NewStreamingClient( + "kitex.service.streaming", + streamxclient.WithHostPorts(addr), + streamxclient.WithProvider(cp), + ) + ctx := context.Background() msg := "BidiStream" + + t.Logf("checking only one connection be reused") var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - for i := 0; i < round; i++ { + for i := 0; i < 12; i++ { + wg.Add(1) + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + req := new(Request) + req.Message = string(make([]byte, 1024)) + err = bs.Send(ctx, req) + test.Assert(t, err == nil, err) + res, err := bs.Recv(ctx) + test.Assert(t, err == nil, err) + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, res.Message == req.Message, res.Message) + runtime.SetFinalizer(bs, func(_ any) { + wg.Done() + t.Logf("stream is finalized") + }) + bs = nil + runtime.GC() + wg.Wait() + } + + t.Logf("checking goroutines destroy") + // checking streaming goroutines + streams := 500 + for i := 0; i < streams; i++ { + wg.Add(1) + go func() { + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) req := new(Request) req.Message = msg - err := bs.Send(ctx, req) + err = bs.Send(ctx, req) test.Assert(t, err == nil, err) + go func() { + defer wg.Done() + res, err := bs.Recv(ctx) + test.Assert(t, err == nil, err) + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, res.Message == msg, res.Message) + + testHeaderAndTrailer(t, bs) + }() + }() + } + wg.Wait() + for { + ng := runtime.NumGoroutine() + if ng-numGoroutine < 10 { + break } - err = bs.CloseSend(ctx) - test.Assert(t, err == nil, err) + runtime.GC() + time.Sleep(time.Second) + t.Logf("current goroutines=%d, before =%d", ng, numGoroutine) + } +} + +func TestTTHeaderStreamingRecvTimeout(t *testing.T) { + var addr = test.GetLocalAddress() + ln, _ := netpoll.CreateListener("tcp", addr) + defer ln.Close() + + // create server + svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) + // register streamingService as ttstreaam provider + sp, _ := ttstream.NewServerProvider(streamingServiceInfo) + _ = svr.RegisterService( + streamingServiceInfo, + new(streamingService), + streamxserver.WithProvider(sp), + ) + go func() { + _ = svr.Run() }() + defer svr.Stop() + test.WaitServerStart(addr) + + cp, _ := ttstream.NewClientProvider( + streamingServiceInfo, + ttstream.WithClientLongConnPool( + ttstream.LongConnConfig{MaxIdleTimeout: time.Second}, + ), + ) + + // timeout by ctx itself + streamClient, _ := NewStreamingClient( + "kitex.service.streaming", + streamxclient.WithHostPorts(addr), + streamxclient.WithProvider(cp), + ) + ctx := context.Background() + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + req := new(Request) + req.Message = string(make([]byte, 1024)) + err = bs.Send(ctx, req) + test.Assert(t, err == nil, err) + ctx, cancel := context.WithCancel(ctx) + cancel() + _, err = bs.Recv(ctx) + test.Assert(t, err != nil, err) + t.Logf("recv timeout error: %v", err) + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + + // timeout by client WithRecvTimeout + streamClient, _ = NewStreamingClient( + "kitex.service.streaming", + streamxclient.WithHostPorts(addr), + streamxclient.WithProvider(cp), + streamxclient.WithRecvTimeout(time.Nanosecond), + ) + ctx = context.Background() + bs, err = streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + req = new(Request) + req.Message = string(make([]byte, 1024)) + err = bs.Send(ctx, req) + test.Assert(t, err == nil, err) + _, err = bs.Recv(ctx) + test.Assert(t, err != nil, err) + t.Logf("recv timeout error: %v", err) + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) +} + +func BenchmarkTTHeaderStreaming(b *testing.B) { + klog.SetLevel(klog.LevelWarn) + var addr = test.GetLocalAddress() + ln, _ := netpoll.CreateListener("tcp", addr) + defer ln.Close() + + // create server + svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) + // register streamingService as ttstreaam provider + sp, _ := ttstream.NewServerProvider(streamingServiceInfo) + _ = svr.RegisterService(streamingServiceInfo, new(streamingService), streamxserver.WithProvider(sp)) go func() { - defer wg.Done() - i := 0 - for { - res, err := bs.Recv(ctx) - if errors.Is(err, io.EOF) { - break - } - i++ - test.Assert(t, err == nil, err) - test.Assert(t, msg == res.Message, res.Message) - } - test.Assert(t, i == round, i) + _ = svr.Run() }() - wg.Wait() - atomic.AddInt32(&serverStreamCount, -1) - waitServerStreamDone() - test.Assert(t, serverRecvCount["BidiStream"] == round, serverRecvCount) - test.Assert(t, serverSendCount["BidiStream"] == round, serverSendCount) - bs = nil - runtime.GC() + defer svr.Stop() + test.WaitServerStart(addr) - streamClient = nil + streamClient, _ := NewStreamingClient("kitex.service.streaming", streamxclient.WithHostPorts(addr)) + ctx := context.Background() + bs, err := streamClient.BidiStream(ctx) + if err != nil { + b.Fatal(err) + } + msg := "BidiStream" + var wg sync.WaitGroup + wg.Add(1) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + req := new(Request) + req.Message = msg + err := bs.Send(ctx, req) + if err != nil { + b.Fatal(err) + } + res, err := bs.Recv(ctx) + if errors.Is(err, io.EOF) { + break + } + _ = res + } + err = bs.CloseSend(ctx) } diff --git a/pkg/streamx/provider/ttstream/ttstream_common_test.go b/pkg/streamx/provider/ttstream/ttstream_common_test.go index 9d424c1826..242a75cd10 100644 --- a/pkg/streamx/provider/ttstream/ttstream_common_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_common_test.go @@ -1,7 +1,24 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream_test import ( "context" + "github.com/bytedance/gopkg/cloud/metainfo" ) diff --git a/pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go b/pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go index 0563679310..2784566ae9 100644 --- a/pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go @@ -1,4 +1,18 @@ -// Code generated by Kitex v1.16.4. DO NOT EDIT. +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package ttstream_test diff --git a/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go b/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go index 1e777b4988..cab4e6ac20 100644 --- a/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream_test import ( @@ -16,13 +32,13 @@ import ( // === gen code === // --- Define Header and Trailer type --- -type ClientStreamingServer[Req, Res any] streamx.ClientStreamingServer[ttstream.Header, ttstream.Trailer, Req, Res] -type ServerStreamingServer[Res any] streamx.ServerStreamingServer[ttstream.Header, ttstream.Trailer, Res] -type BidiStreamingServer[Req, Res any] streamx.BidiStreamingServer[ttstream.Header, ttstream.Trailer, Req, Res] +type ClientStreamingServer[Req, Res any] streamx.ClientStreamingServer[Req, Res] +type ServerStreamingServer[Res any] streamx.ServerStreamingServer[Res] +type BidiStreamingServer[Req, Res any] streamx.BidiStreamingServer[Req, Res] -type ClientStreamingClient[Req, Res any] streamx.ClientStreamingClient[ttstream.Header, ttstream.Trailer, Req, Res] -type ServerStreamingClient[Res any] streamx.ServerStreamingClient[ttstream.Header, ttstream.Trailer, Res] -type BidiStreamingClient[Req, Res any] streamx.BidiStreamingClient[ttstream.Header, ttstream.Trailer, Req, Res] +type ClientStreamingClient[Req, Res any] streamx.ClientStreamingClient[Req, Res] +type ServerStreamingClient[Res any] streamx.ServerStreamingClient[Res] +type BidiStreamingClient[Req, Res any] streamx.BidiStreamingClient[Req, Res] // --- Define Service Method handler --- var pingpongServiceInfo = &serviceinfo.ServiceInfo{ @@ -55,7 +71,7 @@ var streamingServiceInfo = &serviceinfo.ServiceInfo{ Methods: map[string]serviceinfo.MethodInfo{ "Unary": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + return streamxserver.InvokeStream[Request, Response]( ctx, serviceinfo.StreamingUnary, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) }, nil, @@ -65,7 +81,7 @@ var streamingServiceInfo = &serviceinfo.ServiceInfo{ ), "ClientStream": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + return streamxserver.InvokeStream[Request, Response]( ctx, serviceinfo.StreamingClient, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) }, nil, @@ -75,7 +91,7 @@ var streamingServiceInfo = &serviceinfo.ServiceInfo{ ), "ServerStream": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + return streamxserver.InvokeStream[Request, Response]( ctx, serviceinfo.StreamingServer, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) }, nil, @@ -85,7 +101,7 @@ var streamingServiceInfo = &serviceinfo.ServiceInfo{ ), "BidiStream": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + return streamxserver.InvokeStream[Request, Response]( ctx, serviceinfo.StreamingBidirectional, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) }, nil, @@ -94,7 +110,7 @@ var streamingServiceInfo = &serviceinfo.ServiceInfo{ serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), ), }, - Extra: map[string]interface{}{"streaming": true, "streamx": true}, + Extra: map[string]interface{}{"streamingFlag": true, "streamx": true}, } // --- Define RegisterService interface --- @@ -118,12 +134,12 @@ func NewPingPongClient(destService string, opts ...client.Option) (PingPongClien func NewStreamingClient(destService string, opts ...streamxclient.Option) (StreamingClientInterface, error) { var options []streamxclient.Option options = append(options, streamxclient.WithDestService(destService)) - options = append(options, opts...) cp, err := ttstream.NewClientProvider(streamingServiceInfo) if err != nil { return nil, err } options = append(options, streamxclient.WithProvider(cp)) + options = append(options, opts...) cli, err := streamxclient.NewClient(streamingServiceInfo, options...) if err != nil { return nil, err @@ -179,7 +195,7 @@ func (c *kClient) PingPong(ctx context.Context, req *Request) (r *Response, err func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (*Response, error) { res := new(Response) - _, err := streamxclient.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + _, err := streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingUnary, "Unary", req, res, callOptions...) if err != nil { return nil, err @@ -188,18 +204,18 @@ func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...stream } func (c *kClient) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream ClientStreamingClient[Request, Response], err error) { - return streamxclient.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingClient, "ClientStream", nil, nil, callOptions...) } func (c *kClient) ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( stream ServerStreamingClient[Response], err error) { - return streamxclient.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingServer, "ServerStream", req, nil, callOptions...) } func (c *kClient) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( stream BidiStreamingClient[Request, Response], err error) { - return streamxclient.InvokeStream[ttstream.Header, ttstream.Trailer, Request, Response]( + return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingBidirectional, "BidiStream", nil, nil, callOptions...) } diff --git a/pkg/streamx/provider/ttstream/ttstream_server_test.go b/pkg/streamx/provider/ttstream/ttstream_server_test.go index 91e1559108..30a2e212bf 100644 --- a/pkg/streamx/provider/ttstream/ttstream_server_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_server_test.go @@ -1,33 +1,78 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream_test import ( "context" "io" - "log" + + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/ktx" ) type pingpongService struct{} type streamingService struct{} +const ( + headerKey = "header1" + headerVal = "value1" + trailerKey = "trailer1" + trailerVal = "value1" +) + +func (si *streamingService) setHeaderAndTrailer(stream streamx.ServerStreamMetadata) error { + err := stream.SetTrailer(streamx.Trailer{trailerKey: trailerVal}) + if err != nil { + return err + } + err = stream.SendHeader(streamx.Header{headerKey: headerVal}) + if err != nil { + klog.Errorf("send header failed: %v", err) + return err + } + return nil +} + func (si *pingpongService) PingPong(ctx context.Context, req *Request) (*Response, error) { resp := &Response{Type: req.Type, Message: req.Message} - log.Printf("Server PingPong: req={%v} resp={%v}", req, resp) + klog.Infof("Server PingPong: req={%v} resp={%v}", req, resp) return resp, nil } func (si *streamingService) Unary(ctx context.Context, req *Request) (*Response, error) { resp := &Response{Type: req.Type, Message: req.Message} - log.Printf("Server Unary: req={%v} resp={%v}", req, resp) + klog.Infof("Server Unary: req={%v} resp={%v}", req, resp) return resp, nil } -func (si *streamingService) ClientStream(ctx context.Context, stream ClientStreamingServer[Request, Response]) (res *Response, err error) { +func (si *streamingService) ClientStream(ctx context.Context, + stream streamx.ClientStreamingServer[Request, Response]) (*Response, error) { var msg string - defer log.Printf("Server ClientStream end") + klog.Infof("Server ClientStream start") + defer klog.Infof("Server ClientStream end") + + if err := si.setHeaderAndTrailer(stream); err != nil { + return nil, err + } for { req, err := stream.Recv(ctx) if err == io.EOF { - res = new(Response) + res := new(Response) res.Message = msg return res, nil } @@ -35,17 +80,17 @@ func (si *streamingService) ClientStream(ctx context.Context, stream ClientStrea return nil, err } msg = req.Message - log.Printf("Server ClientStream: req={%v}", req) + klog.Infof("Server ClientStream: req={%v}", req) } } -func (si *streamingService) ServerStream(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error { - log.Printf("Server ServerStream: req={%v}", req) +func (si *streamingService) ServerStream(ctx context.Context, req *Request, + stream streamx.ServerStreamingServer[Response]) error { + klog.Infof("Server ServerStream: req={%v}", req) - _ = stream.SetHeader(map[string]string{"key1": "val1"}) - _ = stream.SendHeader(map[string]string{"key2": "val2"}) - _ = stream.SetTrailer(map[string]string{"key1": "val1"}) - _ = stream.SetTrailer(map[string]string{"key2": "val2"}) + if err := si.setHeaderAndTrailer(stream); err != nil { + return err + } for i := 0; i < 3; i++ { resp := new(Response) @@ -55,12 +100,22 @@ func (si *streamingService) ServerStream(ctx context.Context, req *Request, stre if err != nil { return err } - log.Printf("Server ServerStream: send resp={%v}", resp) + klog.Infof("Server ServerStream: send resp={%v}", resp) } return nil } -func (si *streamingService) BidiStream(ctx context.Context, stream BidiStreamingServer[Request, Response]) error { +func (si *streamingService) BidiStream(ctx context.Context, + stream streamx.BidiStreamingServer[Request, Response]) error { + ktx.RegisterCancelCallback(ctx, func() { + klog.Debugf("RegisterCancelCallback work!") + }) + klog.Debugf("RegisterCancelCallback registered!") + + if err := si.setHeaderAndTrailer(stream); err != nil { + return err + } + for { req, err := stream.Recv(ctx) if err == io.EOF { @@ -76,6 +131,6 @@ func (si *streamingService) BidiStream(ctx context.Context, stream BidiStreaming if err != nil { return err } - log.Printf("Server BidiStream: req={%v} resp={%v}", req, resp) + klog.Debugf("Server BidiStream: req={%v} resp={%v}", req, resp) } } diff --git a/pkg/streamx/server_provider.go b/pkg/streamx/server_provider.go index 33c551aaf7..cb04ab3232 100644 --- a/pkg/streamx/server_provider.go +++ b/pkg/streamx/server_provider.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx import ( diff --git a/pkg/streamx/server_provider_internal.go b/pkg/streamx/server_provider_internal.go index 2342aecbc0..643753888b 100644 --- a/pkg/streamx/server_provider_internal.go +++ b/pkg/streamx/server_provider_internal.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx import ( diff --git a/pkg/streamx/stream.go b/pkg/streamx/stream.go index b9abdd0bff..42a7b043cd 100644 --- a/pkg/streamx/stream.go +++ b/pkg/streamx/stream.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx import ( @@ -6,12 +22,12 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) -var _ ServerStreamingClient[int, int, int] = (*GenericClientStream[int, int, int, int])(nil) -var _ ClientStreamingClient[int, int, int, int] = (*GenericClientStream[int, int, int, int])(nil) -var _ BidiStreamingClient[int, int, int, int] = (*GenericClientStream[int, int, int, int])(nil) -var _ ServerStreamingServer[int, int, int] = (*GenericServerStream[int, int, int, int])(nil) -var _ ClientStreamingServer[int, int, int, int] = (*GenericServerStream[int, int, int, int])(nil) -var _ BidiStreamingServer[int, int, int, int] = (*GenericServerStream[int, int, int, int])(nil) +var _ ServerStreamingClient[int] = (*GenericClientStream[int, int])(nil) +var _ ClientStreamingClient[int, int] = (*GenericClientStream[int, int])(nil) +var _ BidiStreamingClient[int, int] = (*GenericClientStream[int, int])(nil) +var _ ServerStreamingServer[int] = (*GenericServerStream[int, int])(nil) +var _ ClientStreamingServer[int, int] = (*GenericServerStream[int, int])(nil) +var _ BidiStreamingServer[int, int] = (*GenericServerStream[int, int])(nil) type StreamingMode = serviceinfo.StreamingMode @@ -87,56 +103,56 @@ type ServerStream interface { // client 必须通过 metainfo.WithValue(ctx, ..) 给下游传递信息 // client 必须通过 metainfo.GetValue(ctx, ..) 拿到当前 server 的透传信息 // client 必须通过 Header() 拿到下游 server 的透传信息 -type ClientStreamMetadata[Header, Trailer any] interface { +type ClientStreamMetadata interface { Header() (Header, error) Trailer() (Trailer, error) } // server 可以通过 Set/SendXXX 给上游回传信息 -type ServerStreamMetadata[Header, Trailer any] interface { +type ServerStreamMetadata interface { SetHeader(hd Header) error SendHeader(hd Header) error SetTrailer(hd Trailer) error } -type ServerStreamingClient[Header, Trailer, Res any] interface { +type ServerStreamingClient[Res any] interface { Recv(ctx context.Context) (*Res, error) ClientStream - ClientStreamMetadata[Header, Trailer] + ClientStreamMetadata } -type ServerStreamingServer[Header, Trailer, Res any] interface { +type ServerStreamingServer[Res any] interface { Send(ctx context.Context, res *Res) error ServerStream - ServerStreamMetadata[Header, Trailer] + ServerStreamMetadata } -type ClientStreamingClient[Header, Trailer, Req, Res any] interface { +type ClientStreamingClient[Req, Res any] interface { Send(ctx context.Context, req *Req) error CloseAndRecv(ctx context.Context) (*Res, error) ClientStream - ClientStreamMetadata[Header, Trailer] + ClientStreamMetadata } -type ClientStreamingServer[Header, Trailer, Req, Res any] interface { +type ClientStreamingServer[Req, Res any] interface { Recv(ctx context.Context) (*Req, error) //SendAndClose(ctx context.Context, res *Res) error ServerStream - ServerStreamMetadata[Header, Trailer] + ServerStreamMetadata } -type BidiStreamingClient[Header, Trailer, Req, Res any] interface { +type BidiStreamingClient[Req, Res any] interface { Send(ctx context.Context, req *Req) error Recv(ctx context.Context) (*Res, error) ClientStream - ClientStreamMetadata[Header, Trailer] + ClientStreamMetadata } -type BidiStreamingServer[Header, Trailer, Req, Res any] interface { +type BidiStreamingServer[Req, Res any] interface { Recv(ctx context.Context) (*Req, error) Send(ctx context.Context, res *Res) error ServerStream - ServerStreamMetadata[Header, Trailer] + ServerStreamMetadata } type GenericStreamIOMiddlewareSetter interface { @@ -144,36 +160,36 @@ type GenericStreamIOMiddlewareSetter interface { SetStreamRecvEndpoint(e StreamSendEndpoint) } -func NewGenericClientStream[Header, Trailer, Req, Res any](cs ClientStream) *GenericClientStream[Header, Trailer, Req, Res] { - return &GenericClientStream[Header, Trailer, Req, Res]{ +func NewGenericClientStream[Req, Res any](cs ClientStream) *GenericClientStream[Req, Res] { + return &GenericClientStream[Req, Res]{ ClientStream: cs, - ClientStreamMetadata: cs.(ClientStreamMetadata[Header, Trailer]), + ClientStreamMetadata: cs.(ClientStreamMetadata), } } -type GenericClientStream[Header, Trailer, Req, Res any] struct { +type GenericClientStream[Req, Res any] struct { ClientStream - ClientStreamMetadata[Header, Trailer] + ClientStreamMetadata StreamSendMiddleware StreamRecvMiddleware } -func (x *GenericClientStream[Header, Trailer, Req, Res]) SetStreamSendMiddleware(e StreamSendMiddleware) { +func (x *GenericClientStream[Req, Res]) SetStreamSendMiddleware(e StreamSendMiddleware) { x.StreamSendMiddleware = e } -func (x *GenericClientStream[Header, Trailer, Req, Res]) SetStreamRecvMiddleware(e StreamRecvMiddleware) { +func (x *GenericClientStream[Req, Res]) SetStreamRecvMiddleware(e StreamRecvMiddleware) { x.StreamRecvMiddleware = e } -func (x *GenericClientStream[Header, Trailer, Req, Res]) SendMsg(ctx context.Context, m any) error { +func (x *GenericClientStream[Req, Res]) SendMsg(ctx context.Context, m any) error { if x.StreamSendMiddleware != nil { return x.StreamSendMiddleware(streamSendNext)(ctx, x.ClientStream, m) } return x.ClientStream.SendMsg(ctx, m) } -func (x *GenericClientStream[Header, Trailer, Req, Res]) RecvMsg(ctx context.Context, m any) (err error) { +func (x *GenericClientStream[Req, Res]) RecvMsg(ctx context.Context, m any) (err error) { if x.StreamRecvMiddleware != nil { err = x.StreamRecvMiddleware(streamRecvNext)(ctx, x.ClientStream, m) } else { @@ -182,11 +198,11 @@ func (x *GenericClientStream[Header, Trailer, Req, Res]) RecvMsg(ctx context.Con return err } -func (x *GenericClientStream[Header, Trailer, Req, Res]) Send(ctx context.Context, m *Req) error { +func (x *GenericClientStream[Req, Res]) Send(ctx context.Context, m *Req) error { return x.SendMsg(ctx, m) } -func (x *GenericClientStream[Header, Trailer, Req, Res]) Recv(ctx context.Context) (m *Res, err error) { +func (x *GenericClientStream[Req, Res]) Recv(ctx context.Context) (m *Res, err error) { m = new(Res) if err = x.RecvMsg(ctx, m); err != nil { return nil, err @@ -194,43 +210,43 @@ func (x *GenericClientStream[Header, Trailer, Req, Res]) Recv(ctx context.Contex return m, nil } -func (x *GenericClientStream[Header, Trailer, Req, Res]) CloseAndRecv(ctx context.Context) (*Res, error) { +func (x *GenericClientStream[Req, Res]) CloseAndRecv(ctx context.Context) (*Res, error) { if err := x.ClientStream.CloseSend(ctx); err != nil { return nil, err } return x.Recv(ctx) } -func NewGenericServerStream[Header, Trailer, Req, Res any](ss ServerStream) *GenericServerStream[Header, Trailer, Req, Res] { - return &GenericServerStream[Header, Trailer, Req, Res]{ +func NewGenericServerStream[Req, Res any](ss ServerStream) *GenericServerStream[Req, Res] { + return &GenericServerStream[Req, Res]{ ServerStream: ss, - ServerStreamMetadata: ss.(ServerStreamMetadata[Header, Trailer]), + ServerStreamMetadata: ss.(ServerStreamMetadata), } } -type GenericServerStream[Header, Trailer, Req, Res any] struct { +type GenericServerStream[Req, Res any] struct { ServerStream - ServerStreamMetadata[Header, Trailer] + ServerStreamMetadata StreamSendMiddleware StreamRecvMiddleware } -func (x *GenericServerStream[Header, Trailer, Req, Res]) SetStreamSendMiddleware(e StreamSendMiddleware) { +func (x *GenericServerStream[Req, Res]) SetStreamSendMiddleware(e StreamSendMiddleware) { x.StreamSendMiddleware = e } -func (x *GenericServerStream[Header, Trailer, Req, Res]) SetStreamRecvMiddleware(e StreamRecvMiddleware) { +func (x *GenericServerStream[Req, Res]) SetStreamRecvMiddleware(e StreamRecvMiddleware) { x.StreamRecvMiddleware = e } -func (x *GenericServerStream[Header, Trailer, Req, Res]) SendMsg(ctx context.Context, m any) error { +func (x *GenericServerStream[Req, Res]) SendMsg(ctx context.Context, m any) error { if x.StreamSendMiddleware != nil { return x.StreamSendMiddleware(streamSendNext)(ctx, x.ServerStream, m) } return x.ServerStream.SendMsg(ctx, m) } -func (x *GenericServerStream[Header, Trailer, Req, Res]) RecvMsg(ctx context.Context, m any) (err error) { +func (x *GenericServerStream[Req, Res]) RecvMsg(ctx context.Context, m any) (err error) { if x.StreamRecvMiddleware != nil { err = x.StreamRecvMiddleware(streamRecvNext)(ctx, x.ServerStream, m) } else { @@ -239,18 +255,18 @@ func (x *GenericServerStream[Header, Trailer, Req, Res]) RecvMsg(ctx context.Con return err } -func (x *GenericServerStream[Header, Trailer, Req, Res]) Send(ctx context.Context, m *Res) error { +func (x *GenericServerStream[Req, Res]) Send(ctx context.Context, m *Res) error { if x.StreamSendMiddleware != nil { return x.StreamSendMiddleware(streamSendNext)(ctx, x.ServerStream, m) } return x.ServerStream.SendMsg(ctx, m) } -func (x *GenericServerStream[Header, Trailer, Req, Res]) SendAndClose(ctx context.Context, m *Res) error { +func (x *GenericServerStream[Req, Res]) SendAndClose(ctx context.Context, m *Res) error { return x.Send(ctx, m) } -func (x *GenericServerStream[Header, Trailer, Req, Res]) Recv(ctx context.Context) (m *Req, err error) { +func (x *GenericServerStream[Req, Res]) Recv(ctx context.Context) (m *Req, err error) { m = new(Req) if x.StreamRecvMiddleware != nil { err = x.StreamRecvMiddleware(streamRecvNext)(ctx, x.ServerStream, m) diff --git a/pkg/streamx/stream_args.go b/pkg/streamx/stream_args.go index 1c0482ff30..432ba26d58 100644 --- a/pkg/streamx/stream_args.go +++ b/pkg/streamx/stream_args.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx import ( diff --git a/pkg/streamx/stream_middleware.go b/pkg/streamx/stream_middleware.go index 30ab1548b1..ed2f8e948b 100644 --- a/pkg/streamx/stream_middleware.go +++ b/pkg/streamx/stream_middleware.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx import ( diff --git a/server/streamxserver/server.go b/server/streamxserver/server.go index 14b9fb68b5..890040e190 100644 --- a/server/streamxserver/server.go +++ b/server/streamxserver/server.go @@ -9,7 +9,7 @@ type Server = server.Server func NewServer(opts ...Option) server.Server { iopts := make([]server.Option, 0, len(opts)+1) for _, opt := range opts { - iopts = append(iopts, convertServerOption(opt)) + iopts = append(iopts, ConvertStreamXServerOption(opt)) } s := server.NewServer(iopts...) return s diff --git a/server/streamxserver/server_gen.go b/server/streamxserver/server_gen.go index 7d5ed074bf..9b46f928a1 100644 --- a/server/streamxserver/server_gen.go +++ b/server/streamxserver/server_gen.go @@ -9,7 +9,7 @@ import ( "github.com/cloudwego/kitex/pkg/streamx" ) -func InvokeStream[Header, Trailer, Req, Res any]( +func InvokeStream[Req, Res any]( ctx context.Context, smode serviceinfo.StreamingMode, handler any, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { // prepare args @@ -18,7 +18,7 @@ func InvokeStream[Header, Trailer, Req, Res any]( return errors.New("server stream is nil") } shandler := handler.(streamx.StreamHandler) - gs := streamx.NewGenericServerStream[Header, Trailer, Req, Res](sArgs.Stream()) + gs := streamx.NewGenericServerStream[Req, Res](sArgs.Stream()) gs.SetStreamRecvMiddleware(shandler.StreamRecvMiddleware) gs.SetStreamSendMiddleware(shandler.StreamSendMiddleware) diff --git a/server/streamxserver/server_option.go b/server/streamxserver/server_option.go index ca4558d6a5..f9233703db 100644 --- a/server/streamxserver/server_option.go +++ b/server/streamxserver/server_option.go @@ -12,7 +12,7 @@ type Option internal_server.Option type Options = internal_server.Options func WithListener(ln net.Listener) Option { - return convertInternalServerOption(server.WithListener(ln)) + return ConvertNativeServerOption(server.WithListener(ln)) } func WithStreamMiddleware(mw streamx.StreamMiddleware) server.RegisterOption { @@ -39,10 +39,10 @@ func WithProvider(provider streamx.ServerProvider) server.RegisterOption { }} } -func convertInternalServerOption(o internal_server.Option) Option { +func ConvertNativeServerOption(o internal_server.Option) Option { return Option{F: o.F} } -func convertServerOption(o Option) internal_server.Option { +func ConvertStreamXServerOption(o Option) internal_server.Option { return internal_server.Option{F: o.F} } diff --git a/tool/internal_pkg/tpl/streamx/service.go b/tool/internal_pkg/tpl/streamx/service.go index 3c791012a2..b55bf2e80e 100644 --- a/tool/internal_pkg/tpl/streamx/service.go +++ b/tool/internal_pkg/tpl/streamx/service.go @@ -45,6 +45,7 @@ var svcInfo = &serviceinfo.ServiceInfo{ }, Extra: map[string]interface{}{ "streaming": true, + "streamx": true, }, } From 0ca06023f51e69db91fde8b437623e26c88ce21d Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Thu, 10 Oct 2024 17:13:44 +0800 Subject: [PATCH 05/34] feat: support ttheader streaming error handling (#1566) --- pkg/remote/trans/streamx/server_handler.go | 7 +- pkg/streamx/client_provider.go | 2 +- .../provider/jsonrpc/server_provider.go | 8 +- .../ttstream/client_trans_pool_longconn.go | 3 +- .../ttstream/client_trans_pool_mux.go | 3 +- pkg/streamx/provider/ttstream/exception.go | 6 + pkg/streamx/provider/ttstream/frame.go | 18 ++ .../provider/ttstream/server_provider.go | 33 +++- pkg/streamx/provider/ttstream/stream.go | 48 ++++- pkg/streamx/provider/ttstream/stream_io.go | 19 +- pkg/streamx/provider/ttstream/transport.go | 37 ++-- .../provider/ttstream/transport_test.go | 5 +- .../provider/ttstream/ttstream_client_test.go | 3 +- .../provider/ttstream/ttstream_error_test.go | 174 ++++++++++++++++++ .../ttstream/ttstream_gen_service_test.go | 78 ++++++++ .../provider/ttstream/ttstream_server_test.go | 49 +++++ pkg/streamx/server_provider.go | 6 +- pkg/transmeta/ttheader.go | 29 ++- tool/internal_pkg/generator/generator.go | 3 - tool/internal_pkg/tpl/streamx/client.go | 20 +- .../tpl/streamx/handler.method.go | 6 +- tool/internal_pkg/tpl/streamx/server.go | 6 +- tool/internal_pkg/tpl/streamx/service.go | 2 +- 23 files changed, 491 insertions(+), 74 deletions(-) create mode 100644 pkg/streamx/provider/ttstream/exception.go create mode 100644 pkg/streamx/provider/ttstream/ttstream_error_test.go diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index 011d555a7e..452f40ba47 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -157,7 +157,12 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss stream reqArgs := streamx.NewStreamReqArgs(nil) resArgs := streamx.NewStreamResArgs(nil) serr := t.inkHdlFunc(ctx, reqArgs, resArgs) - ctx, err = t.provider.OnStreamFinish(ctx, ss) + if serr == nil { + if bizErr := ri.Invocation().BizStatusErr(); bizErr != nil { + serr = bizErr + } + } + ctx, err = t.provider.OnStreamFinish(ctx, ss, serr) if err == nil && serr != nil { err = serr } diff --git a/pkg/streamx/client_provider.go b/pkg/streamx/client_provider.go index a4543ef2f9..7706b38887 100644 --- a/pkg/streamx/client_provider.go +++ b/pkg/streamx/client_provider.go @@ -23,7 +23,7 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" ) -/* Hot it works +/* How it works clientProvider := xxx.NewClientProvider(xxx.WithXXX(...)) client := {user_gencode}.NewClient({kitex_client}.WithClientProvider(clientProvider)) diff --git a/pkg/streamx/provider/jsonrpc/server_provider.go b/pkg/streamx/provider/jsonrpc/server_provider.go index f689c6b7c5..6f74bde151 100644 --- a/pkg/streamx/provider/jsonrpc/server_provider.go +++ b/pkg/streamx/provider/jsonrpc/server_provider.go @@ -20,9 +20,10 @@ import ( "context" "net" + "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/netpoll" ) type serverTransCtxKey struct{} @@ -70,8 +71,7 @@ func (s serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.Co return ctx, ss, nil } -func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStream) (context.Context, error) { +func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStream, err error) (context.Context, error) { sst := ss.(*serverStream) - err := sst.sendEOF() - return ctx, err + return ctx, sst.sendEOF() } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go index c7657ea929..3b3dff2750 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go @@ -19,9 +19,10 @@ package ttstream import ( "time" + "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" - "github.com/cloudwego/netpoll" ) var DefaultLongConnConfig = LongConnConfig{ diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go index 4775cb40bc..f8a9330ff3 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go @@ -21,9 +21,10 @@ import ( "sync" "time" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/netpoll" "golang.org/x/sync/singleflight" + + "github.com/cloudwego/kitex/pkg/serviceinfo" ) var _ transPool = (*muxTransPool)(nil) diff --git a/pkg/streamx/provider/ttstream/exception.go b/pkg/streamx/provider/ttstream/exception.go new file mode 100644 index 0000000000..5c650784b2 --- /dev/null +++ b/pkg/streamx/provider/ttstream/exception.go @@ -0,0 +1,6 @@ +package ttstream + +type tException interface { + Error() string + TypeId() int32 +} diff --git a/pkg/streamx/provider/ttstream/frame.go b/pkg/streamx/provider/ttstream/frame.go index 7732f079d1..97a6bb8938 100644 --- a/pkg/streamx/provider/ttstream/frame.go +++ b/pkg/streamx/provider/ttstream/frame.go @@ -26,6 +26,7 @@ import ( "github.com/cloudwego/gopkg/bufiox" gopkgthrift "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/streamx" ) @@ -51,6 +52,7 @@ type Frame struct { meta IntHeader typ int32 payload []byte + err error } func newFrame(sframe streamFrame, meta IntHeader, typ int32, payload []byte) (fr *Frame) { @@ -67,11 +69,23 @@ func newFrame(sframe streamFrame, meta IntHeader, typ int32, payload []byte) (fr return fr } +func newErrFrame(err error) (fr *Frame) { + v := framePool.Get() + if v == nil { + fr = new(Frame) + } else { + fr = v.(*Frame) + } + fr.err = err + return fr +} + func recycleFrame(frame *Frame) { frame.streamFrame = streamFrame{} frame.meta = nil frame.typ = 0 frame.payload = nil + frame.err = nil framePool.Put(frame) } @@ -180,3 +194,7 @@ func EncodePayload(ctx context.Context, msg any) ([]byte, error) { func DecodePayload(ctx context.Context, payload []byte, msg any) error { return thrift.UnmarshalThriftData(ctx, thriftCodec, "", payload, msg) } + +func EncodeException(ctx context.Context, method string, seq int32, ex tException) ([]byte, error) { + return gopkgthrift.MarshalFastMsg(method, gopkgthrift.EXCEPTION, seq, ex.(gopkgthrift.FastCodec)) +} diff --git a/pkg/streamx/provider/ttstream/server_provider.go b/pkg/streamx/provider/ttstream/server_provider.go index 5ecdfb26ca..7c9add2f80 100644 --- a/pkg/streamx/provider/ttstream/server_provider.go +++ b/pkg/streamx/provider/ttstream/server_provider.go @@ -19,13 +19,19 @@ package ttstream import ( "context" "net" + "strconv" "github.com/bytedance/gopkg/cloud/metainfo" + "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/ktx" - "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/utils" ) type serverTransCtxKey struct{} @@ -86,9 +92,30 @@ func (s serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.Co return ctx, ss, nil } -func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStream) (context.Context, error) { +func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStream, err error) (context.Context, error) { sst := ss.(*serverStream) - _ = sst.close() + var exception tException + if err != nil { + switch err.(type) { + case tException: + exception = err.(tException) + case kerrors.BizStatusErrorIface: + bizErr := err.(kerrors.BizStatusErrorIface) + sst.appendTrailer( + "biz-status", strconv.Itoa(int(bizErr.BizStatusCode())), + "biz-message", bizErr.BizMessage(), + ) + if bizErr.BizExtra() != nil { + extra, _ := utils.Map2JSONStr(bizErr.BizExtra()) + sst.appendTrailer("biz-extra", extra) + } + default: + exception = thrift.NewApplicationException(remote.InternalError, err.Error()) + } + } + if err := sst.close(exception); err != nil { + return nil, err + } cancelFunc, _ := ctx.Value(serverStreamCancelCtxKey{}).(context.CancelFunc) if cancelFunc != nil { diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index 20f0915627..c6196e26a2 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -23,9 +23,12 @@ import ( "sync/atomic" "time" + "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/transmeta" ) var ( @@ -72,6 +75,7 @@ type stream struct { peerEOF int32 headerSig chan int32 trailerSig chan int32 + err error StreamMeta metaHandler MetaFrameHandler @@ -157,12 +161,24 @@ func (s *stream) sendHeader() (err error) { // readTrailer by client: unblock recv function and return EOF if no unread frame // readTrailer by server: unblock recv function and return EOF if no unread frame -func (s *stream) readTrailer(tl streamx.Trailer) (err error) { +func (s *stream) readTrailerFrame(fr *Frame) (err error) { if !atomic.CompareAndSwapInt32(&s.peerEOF, 0, 1) { return fmt.Errorf("stream read a unexcept trailer") } - s.trailer = tl + // when server-side returns non-biz error, it will be wrapped as ApplicationException stored in trailer frame payload + if len(fr.payload) > 0 { + _, _, ex := thrift.UnmarshalFastMsg(fr.payload, nil) + s.err = ex.(*thrift.ApplicationException) + } else { // when server-side returns biz error, payload is empty and biz error information is stored in trailer frame header + bizErr, err := transmeta.ParseBizStatusErr(fr.trailer) + if err != nil { + s.err = err + } else if bizErr != nil { + s.err = bizErr + } + } + s.trailer = fr.trailer select { case s.trailerSig <- streamSigActive: default: @@ -173,7 +189,8 @@ func (s *stream) readTrailer(tl streamx.Trailer) (err error) { // if trailer arrived, we should return unblock stream.Header() default: } - klog.Debugf("stream[%d] recv trailer: %v", s.sid, tl) + + klog.Debugf("stream[%d] recv trailer: %v, err: %v", s.sid, s.trailer, s.err) return s.trans.streamCloseRecv(s) } @@ -187,7 +204,22 @@ func (s *stream) writeTrailer(tl streamx.Trailer) (err error) { return nil } -func (s *stream) sendTrailer() (err error) { +func (s *stream) appendTrailer(kvs ...string) (err error) { + if len(kvs)%2 != 0 { + return fmt.Errorf("got the odd number of input kvs for Trailer: %d", len(kvs)) + } + var key string + for i, str := range kvs { + if i%2 == 0 { + key = str + continue + } + s.wtrailer[key] = str + } + return nil +} + +func (s *stream) sendTrailer(ctx context.Context, ex tException) (err error) { if !atomic.CompareAndSwapInt32(&s.selfEOF, 0, 1) { return nil } @@ -197,7 +229,7 @@ func (s *stream) sendTrailer() (err error) { return fmt.Errorf("stream trailer already sent") } klog.Debugf("transport[%d]-stream[%d] send trialer", s.trans.kind, s.sid) - return s.trans.streamCloseSend(s.sid, s.method, wtrailer) + return s.trans.streamCloseSend(s.sid, s.method, wtrailer, ex) } func (s *stream) finished() bool { @@ -244,7 +276,7 @@ func (s *clientStream) RecvMsg(ctx context.Context, req any) error { } func (s *clientStream) CloseSend(ctx context.Context) error { - return s.sendTrailer() + return s.sendTrailer(ctx, nil) } func newServerStream(s *stream) streamx.ServerStream { @@ -272,9 +304,9 @@ func (s *serverStream) SendMsg(ctx context.Context, res any) error { // close will be called after server handler returned // after close stream cannot be access again -func (s *serverStream) close() error { +func (s *serverStream) close(ex tException) error { // write loop should help to delete stream - err := s.sendTrailer() + err := s.sendTrailer(context.Background(), ex) if err != nil { return err } diff --git a/pkg/streamx/provider/ttstream/stream_io.go b/pkg/streamx/provider/ttstream/stream_io.go index a03cc6c0ea..65851ca4dd 100644 --- a/pkg/streamx/provider/ttstream/stream_io.go +++ b/pkg/streamx/provider/ttstream/stream_io.go @@ -37,6 +37,7 @@ type streamIO struct { eofCallback func() fpipe *container.Pipe[*Frame] fcache [1]*Frame + err error } func newStreamIO(ctx context.Context, s *stream) *streamIO { @@ -60,17 +61,27 @@ func (s *streamIO) input(ctx context.Context, f *Frame) { } func (s *streamIO) output(ctx context.Context) (f *Frame, err error) { + if s.err != nil { + return nil, s.err + } n, err := s.fpipe.Read(ctx, s.fcache[:]) if err != nil { if errors.Is(err, container.ErrPipeEOF) { - return nil, io.EOF + err = io.EOF } - return nil, err + s.err = err + return nil, s.err } if n == 0 { - return nil, io.EOF + s.err = io.EOF + return nil, s.err + } + f = s.fcache[0] + if f.err != nil { + s.err = f.err + return nil, s.err } - return s.fcache[0], nil + return f, nil } func (s *streamIO) closeRecv() { diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index eb03cd458f..ab7825dfec 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -28,11 +28,12 @@ import ( "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" - "github.com/cloudwego/netpoll" ) const ( @@ -209,7 +210,7 @@ func (t *transport) loopRead() error { t.kind, addr, fr.sid, fr.trailer) continue } - if err = sio.stream.readTrailer(fr.trailer); err != nil { + if err = sio.stream.readTrailerFrame(fr); err != nil { return err } } @@ -257,15 +258,7 @@ func (t *transport) loopWrite() (err error) { } // writeFrame is concurrent safe -func (t *transport) writeFrame(sframe streamFrame, meta IntHeader, ftype int32, data any) (err error) { - var payload []byte - if data != nil { - // payload should be written nocopy - payload, err = EncodePayload(context.Background(), data) - if err != nil { - return err - } - } +func (t *transport) writeFrame(sframe streamFrame, meta IntHeader, ftype int32, payload []byte) (err error) { frame := newFrame(sframe, meta, ftype, payload) t.wchannel <- frame return nil @@ -278,9 +271,13 @@ func (t *transport) streamSend(ctx context.Context, sid int32, method string, wh return err } } + payload, err := EncodePayload(ctx, res) + if err != nil { + return err + } return t.writeFrame( streamFrame{sid: sid, method: method}, - nil, dataFrameType, res, + nil, dataFrameType, payload, ) } @@ -290,10 +287,17 @@ func (t *transport) streamSendHeader(sid int32, method string, header streamx.He nil, headerFrameType, nil) } -func (t *transport) streamCloseSend(sid int32, method string, trailer streamx.Trailer) (err error) { +func (t *transport) streamCloseSend(sid int32, method string, trailer streamx.Trailer, ex tException) (err error) { + var payload []byte + if ex != nil { + payload, err = EncodeException(context.Background(), method, sid, ex) + if err != nil { + return err + } + } err = t.writeFrame( streamFrame{sid: sid, method: method, trailer: trailer}, - nil, trailerFrameType, nil, + nil, trailerFrameType, payload, ) if err != nil { return err @@ -322,11 +326,14 @@ func (t *transport) streamRecv(ctx context.Context, sid int32, data any) (err er return nil } -func (t *transport) streamCloseRecv(s *stream) (err error) { +func (t *transport) streamCloseRecv(s *stream) error { sio, ok := t.loadStreamIO(s.sid) if !ok { return fmt.Errorf("stream not found in stream map: sid=%d", s.sid) } + if s.err != nil { + sio.input(context.Background(), newErrFrame(s.err)) + } sio.closeRecv() return nil } diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go index ec85dd393c..729d49b121 100644 --- a/pkg/streamx/provider/ttstream/transport_test.go +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -26,10 +26,11 @@ import ( "testing" "time" + "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/netpoll" ) func TestTransport(t *testing.T) { @@ -83,7 +84,7 @@ func TestTransport(t *testing.T) { test.Assert(t, err == nil, err) // send trailer - err = ss.(*serverStream).sendTrailer() + err = ss.(*serverStream).sendTrailer(ctx, nil) test.Assert(t, err == nil, err) atomic.AddInt32(&streamDone, -1) }() diff --git a/pkg/streamx/provider/ttstream/ttstream_client_test.go b/pkg/streamx/provider/ttstream/ttstream_client_test.go index 4aa3bbe9bb..10c1e96ede 100644 --- a/pkg/streamx/provider/ttstream/ttstream_client_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_client_test.go @@ -29,6 +29,8 @@ import ( "testing" "time" + "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/streamxclient" "github.com/cloudwego/kitex/internal/test" @@ -39,7 +41,6 @@ import ( "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/streamxserver" "github.com/cloudwego/kitex/transport" - "github.com/cloudwego/netpoll" ) func init() { diff --git a/pkg/streamx/provider/ttstream/ttstream_error_test.go b/pkg/streamx/provider/ttstream/ttstream_error_test.go new file mode 100644 index 0000000000..7153bd60c3 --- /dev/null +++ b/pkg/streamx/provider/ttstream/ttstream_error_test.go @@ -0,0 +1,174 @@ +package ttstream_test + +import ( + "context" + "testing" + "time" + + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/client/streamxclient" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" + "github.com/cloudwego/kitex/server" + "github.com/cloudwego/kitex/server/streamxserver" +) + +const ( + normalErr int32 = iota + 1 + bizErr +) + +var ( + testCode = int32(10001) + testMsg = "biz testMsg" + testExtra = map[string]string{ + "testKey": "testVal", + } + normalErrMsg = "normal error" +) + +func assertNormalErr(t *testing.T, err error) { + ex, ok := err.(*thrift.ApplicationException) + test.Assert(t, ok, err) + test.Assert(t, ex.TypeID() == remote.InternalError, ex.TypeID()) + test.Assert(t, ex.Msg() == "biz error: "+normalErrMsg, ex.Msg()) +} + +func assertBizErr(t *testing.T, err error) { + bizIntf, ok := kerrors.FromBizStatusError(err) + test.Assert(t, ok) + test.Assert(t, bizIntf.BizStatusCode() == testCode, bizIntf.BizStatusCode()) + test.Assert(t, bizIntf.BizMessage() == testMsg, bizIntf.BizMessage()) + test.DeepEqual(t, bizIntf.BizExtra(), testExtra) +} + +func TestTTHeaderStreamingErrorHandling(t *testing.T) { + klog.SetLevel(klog.LevelDebug) + var addr = test.GetLocalAddress() + ln, err := netpoll.CreateListener("tcp", addr) + test.Assert(t, err == nil, err) + defer ln.Close() + + svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) + sp, err := ttstream.NewServerProvider(streamingServiceInfo) + test.Assert(t, err == nil, err) + err = svr.RegisterService( + streamingServiceInfo, + new(streamingService), + streamxserver.WithProvider(sp), + ) + test.Assert(t, err == nil, err) + go func() { + err := svr.Run() + test.Assert(t, err == nil, err) + }() + defer svr.Stop() + test.WaitServerStart(addr) + + streamClient, err := NewStreamingClient( + "kitex.service.streaming", + streamxclient.WithHostPorts(addr), + ) + test.Assert(t, err == nil, err) + + t.Logf("=== UnaryWithErr normalErr ===") + req := new(Request) + req.Type = normalErr + res, err := streamClient.UnaryWithErr(context.Background(), req) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertNormalErr(t, err) + + t.Logf("=== UnaryWithErr bizErr ===") + req = new(Request) + req.Type = bizErr + res, err = streamClient.UnaryWithErr(context.Background(), req) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertBizErr(t, err) + + t.Logf("=== ClientStreamWithErr normalErr ===") + ctx := context.Background() + cliStream, err := streamClient.ClientStreamWithErr(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, cliStream != nil, cliStream) + req = new(Request) + req.Type = normalErr + err = cliStream.Send(ctx, req) + test.Assert(t, err == nil, err) + res, err = cliStream.CloseAndRecv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertNormalErr(t, err) + + t.Logf("=== ClientStreamWithErr bizErr ===") + ctx = context.Background() + cliStream, err = streamClient.ClientStreamWithErr(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, cliStream != nil, cliStream) + req = new(Request) + req.Type = bizErr + err = cliStream.Send(ctx, req) + test.Assert(t, err == nil, err) + res, err = cliStream.CloseAndRecv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertBizErr(t, err) + + t.Logf("=== ServerStreamWithErr normalErr ===") + ctx = context.Background() + req = new(Request) + req.Type = normalErr + svrStream, err := streamClient.ServerStreamWithErr(ctx, req) + test.Assert(t, err == nil, err) + test.Assert(t, svrStream != nil, svrStream) + res, err = svrStream.Recv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertNormalErr(t, err) + + t.Logf("=== ServerStreamWithErr bizErr ===") + ctx = context.Background() + req = new(Request) + req.Type = bizErr + svrStream, err = streamClient.ServerStreamWithErr(ctx, req) + test.Assert(t, err == nil, err) + test.Assert(t, svrStream != nil, svrStream) + res, err = svrStream.Recv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertBizErr(t, err) + + t.Logf("=== BidiStreamWithErr normalErr ===") + ctx = context.Background() + bidiStream, err := streamClient.BidiStreamWithErr(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, bidiStream != nil, bidiStream) + req = new(Request) + req.Type = normalErr + err = bidiStream.Send(ctx, req) + test.Assert(t, err == nil, err) + res, err = bidiStream.Recv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertNormalErr(t, err) + + t.Logf("=== BidiStreamWithErr bizErr ===") + ctx = context.Background() + bidiStream, err = streamClient.BidiStreamWithErr(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, bidiStream != nil, bidiStream) + req = new(Request) + req.Type = bizErr + err = bidiStream.Send(ctx, req) + test.Assert(t, err == nil, err) + res, err = bidiStream.Recv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertBizErr(t, err) +} diff --git a/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go b/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go index cab4e6ac20..59062f789f 100644 --- a/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go @@ -109,6 +109,46 @@ var streamingServiceInfo = &serviceinfo.ServiceInfo{ false, serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), ), + "UnaryWithErr": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[Request, Response]( + ctx, serviceinfo.StreamingUnary, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingUnary), + ), + "ClientStreamWithErr": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[Request, Response]( + ctx, serviceinfo.StreamingClient, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingClient), + ), + "ServerStreamWithErr": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[Request, Response]( + ctx, serviceinfo.StreamingServer, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingServer), + ), + "BidiStreamWithErr": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[Request, Response]( + ctx, serviceinfo.StreamingBidirectional, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + ), }, Extra: map[string]interface{}{"streamingFlag": true, "streamx": true}, } @@ -157,6 +197,10 @@ type StreamingServerInterface interface { ClientStream(ctx context.Context, stream ClientStreamingServer[Request, Response]) (*Response, error) ServerStream(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error BidiStream(ctx context.Context, stream BidiStreamingServer[Request, Response]) error + UnaryWithErr(ctx context.Context, req *Request) (*Response, error) + ClientStreamWithErr(ctx context.Context, stream ClientStreamingServer[Request, Response]) (*Response, error) + ServerStreamWithErr(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error + BidiStreamWithErr(ctx context.Context, stream BidiStreamingServer[Request, Response]) error } // --- Define Client Implementation Interface --- @@ -172,6 +216,13 @@ type StreamingClientInterface interface { stream ServerStreamingClient[Response], err error) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( stream BidiStreamingClient[Request, Response], err error) + UnaryWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error) + ClientStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream ClientStreamingClient[Request, Response], err error) + ServerStreamWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( + stream ServerStreamingClient[Response], err error) + BidiStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream BidiStreamingClient[Request, Response], err error) } // --- Define Client Implementation --- @@ -219,3 +270,30 @@ func (c *kClient) BidiStream(ctx context.Context, callOptions ...streamxcallopt. return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingBidirectional, "BidiStream", nil, nil, callOptions...) } + +func (c *kClient) UnaryWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (*Response, error) { + res := new(Response) + _, err := streamxclient.InvokeStream[Request, Response]( + ctx, c.streamer, serviceinfo.StreamingUnary, "UnaryWithErr", req, res, callOptions...) + if err != nil { + return nil, err + } + return res, nil +} + +func (c *kClient) ClientStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream ClientStreamingClient[Request, Response], err error) { + return streamxclient.InvokeStream[Request, Response]( + ctx, c.streamer, serviceinfo.StreamingClient, "ClientStreamWithErr", nil, nil, callOptions...) +} + +func (c *kClient) ServerStreamWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( + stream ServerStreamingClient[Response], err error) { + return streamxclient.InvokeStream[Request, Response]( + ctx, c.streamer, serviceinfo.StreamingServer, "ServerStreamWithErr", req, nil, callOptions...) +} + +func (c *kClient) BidiStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream BidiStreamingClient[Request, Response], err error) { + return streamxclient.InvokeStream[Request, Response]( + ctx, c.streamer, serviceinfo.StreamingBidirectional, "BidiStreamWithErr", nil, nil, callOptions...) +} diff --git a/pkg/streamx/provider/ttstream/ttstream_server_test.go b/pkg/streamx/provider/ttstream/ttstream_server_test.go index 30a2e212bf..3f21f24d9c 100644 --- a/pkg/streamx/provider/ttstream/ttstream_server_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_server_test.go @@ -18,8 +18,10 @@ package ttstream_test import ( "context" + "errors" "io" + "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/ktx" @@ -134,3 +136,50 @@ func (si *streamingService) BidiStream(ctx context.Context, klog.Debugf("Server BidiStream: req={%v} resp={%v}", req, resp) } } + +func buildErr(req *Request) error { + var err error + switch req.Type { + case normalErr: + err = errors.New(normalErrMsg) + case bizErr: + err = kerrors.NewBizStatusErrorWithExtra(testCode, testMsg, testExtra) + default: + klog.Fatalf("Unsupported Err Type: %d", req.Type) + } + return err +} + +func (si *streamingService) UnaryWithErr(ctx context.Context, req *Request) (*Response, error) { + err := buildErr(req) + klog.Infof("Server UnaryWithErr: req={%v} err={%v}", req, err) + return nil, err +} + +func (si *streamingService) ClientStreamWithErr(ctx context.Context, stream ClientStreamingServer[Request, Response]) (res *Response, err error) { + req, err := stream.Recv(ctx) + if err != nil { + klog.Errorf("Server ClientStreamWithErr Recv failed, err={%v}", err) + return nil, err + } + err = buildErr(req) + klog.Infof("Server ClientStreamWithErr: req={%v} err={%v}", req, err) + return nil, err +} + +func (si *streamingService) ServerStreamWithErr(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error { + err := buildErr(req) + klog.Infof("Server ServerStreamWithErr: req={%v} err={%v}", req, err) + return err +} + +func (si *streamingService) BidiStreamWithErr(ctx context.Context, stream BidiStreamingServer[Request, Response]) error { + req, err := stream.Recv(ctx) + if err != nil { + klog.Errorf("Server BidiStreamWithErr Recv failed, err={%v}", err) + return err + } + err = buildErr(req) + klog.Infof("Server BidiStreamWithErr: req={%v} err={%v}", req, err) + return err +} diff --git a/pkg/streamx/server_provider.go b/pkg/streamx/server_provider.go index cb04ab3232..d044853172 100644 --- a/pkg/streamx/server_provider.go +++ b/pkg/streamx/server_provider.go @@ -21,7 +21,7 @@ import ( "net" ) -/* Hot it works +/* How it works serverProvider := xxx.NewServerProvider(xxx.WithXXX()...) server := {user_gencode}.NewServer({kitex_server}.WithServerProvider(serverProvider)) @@ -44,7 +44,7 @@ res := stream.Recv(...) stream.Close() - server handler return */ -/* Hot it works +/* How it works - NewServer 时,初始化 ServerProvider,并注册 streamx.ServerTransHandler - 连接进来的时候,detection trans handler 会转发调用 streamx.ServerTransHandler - streamx.ServerTransHandler 负责调用 ServerProvider 的相关方法 @@ -63,5 +63,5 @@ type ServerProvider interface { // OnStream should read conn data and return a server stream OnStream(ctx context.Context, conn net.Conn) (context.Context, ServerStream, error) // OnStreamFinish should be called when user server handler returned, typically provide should close the stream - OnStreamFinish(ctx context.Context, ss ServerStream) (context.Context, error) + OnStreamFinish(ctx context.Context, ss ServerStream, err error) (context.Context, error) } diff --git a/pkg/transmeta/ttheader.go b/pkg/transmeta/ttheader.go index 8c1bb81325..416c477e13 100644 --- a/pkg/transmeta/ttheader.go +++ b/pkg/transmeta/ttheader.go @@ -115,20 +115,29 @@ func (ch *clientTTHeaderHandler) ReadMeta(ctx context.Context, msg remote.Messag transInfo := msg.TransInfo() strInfo := transInfo.TransStrInfo() + bizErr, err := ParseBizStatusErr(strInfo) + if err != nil { + return ctx, err + } + if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok && bizErr != nil { + setter.SetBizStatusErr(bizErr) + } + return ctx, nil +} + +func ParseBizStatusErr(strInfo map[string]string) (kerrors.BizStatusErrorIface, error) { if code, err := strconv.Atoi(strInfo[bizStatus]); err == nil && code != 0 { - if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { - if bizExtra := strInfo[bizExtra]; bizExtra != "" { - extra, err := utils.JSONStr2Map(bizExtra) - if err != nil { - return ctx, fmt.Errorf("malformed header info, extra: %s", bizExtra) - } - setter.SetBizStatusErr(kerrors.NewBizStatusErrorWithExtra(int32(code), strInfo[bizMessage], extra)) - } else { - setter.SetBizStatusErr(kerrors.NewBizStatusError(int32(code), strInfo[bizMessage])) + if bizExtra := strInfo[bizExtra]; bizExtra != "" { + extra, err := utils.JSONStr2Map(bizExtra) + if err != nil { + return nil, fmt.Errorf("malformed header info, extra: %s", bizExtra) } + return kerrors.NewBizStatusErrorWithExtra(int32(code), strInfo[bizMessage], extra), nil + } else { + return kerrors.NewBizStatusError(int32(code), strInfo[bizMessage]), nil } } - return ctx, nil + return nil, nil } // serverTTHeaderHandler implement remote.MetaHandler diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 6a82709819..29d444d271 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -728,9 +728,6 @@ func (g *generator) setStreamXServiceImports(pkg *PackageInfo) { pkg.AddImports("context") pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") pkg.AddImports("github.com/cloudwego/kitex/server/streamxserver") - if g.IDLType == "thrift" { - pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef) - } for _, a := range m.Args { for _, dep := range a.Deps { pkg.AddImport(dep.PkgRefName, dep.ImportPath) diff --git a/tool/internal_pkg/tpl/streamx/client.go b/tool/internal_pkg/tpl/streamx/client.go index cf95aded81..f88fce814e 100644 --- a/tool/internal_pkg/tpl/streamx/client.go +++ b/tool/internal_pkg/tpl/streamx/client.go @@ -25,9 +25,9 @@ type Client interface { {{- $bidiSide := and .ClientStreaming .ServerStreaming}} {{- $arg := index .Args 0}} {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (r {{.Resp.Type}}, err error) - {{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.ClientStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) - {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (stream streamx.ServerStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr .Resp.Type}}], err error) - {{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.BidiStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) + {{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.ClientStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) + {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (stream streamx.ServerStreamingClient[{{NotPtr .Resp.Type}}], err error) + {{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.BidiStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) {{- end}} {{- end}} } @@ -70,20 +70,20 @@ type kClient struct { {{- $arg := index .Args 0}} func (c *kClient) {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) ({{.Resp.Type}}, error) { res := new({{NotPtr .Resp.Type}}) - _, err := streamxclient.InvokeStream[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + _, err := streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( ctx, c.streamer, {{$mode}}, "{{.RawName}}", req, res, callOptions...) if err != nil { return nil, err } return res, nil -{{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.ClientStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) { - return streamxclient.InvokeStream[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( +{{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.ClientStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) { + return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( ctx, c.streamer, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) -{{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (stream streamx.ServerStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr .Resp.Type}}], err error) { - return streamxclient.InvokeStream[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( +{{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (stream streamx.ServerStreamingClient[{{NotPtr .Resp.Type}}], err error) { + return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( ctx, c.streamer, {{$mode}}, "{{.RawName}}", req, nil, callOptions...) -{{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.BidiStreamingClient[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) { - return streamxclient.InvokeStream[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( +{{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.BidiStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) { + return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( ctx, c.streamer, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) {{- end}} } diff --git a/tool/internal_pkg/tpl/streamx/handler.method.go b/tool/internal_pkg/tpl/streamx/handler.method.go index a46521717f..0224237e3d 100644 --- a/tool/internal_pkg/tpl/streamx/handler.method.go +++ b/tool/internal_pkg/tpl/streamx/handler.method.go @@ -9,9 +9,9 @@ var HandlerMethodsTpl = `{{define "HandlerMethod"}} {{- $bidiSide := and .ClientStreaming .ServerStreaming}} {{- $arg := index .Args 0}} func (s *{{.ServiceName}}Impl) {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}) (resp {{.Resp.Type}}, err error) { - {{- else if $clientSide}}(ctx context.Context, stream streamx.ClientStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) (resp {{.Resp.Type}}, err error) { - {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr .Resp.Type}}]) (err error) { - {{- else if $bidiSide}}(ctx context.Context, stream streamx.BidiStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) (err error) { + {{- else if $clientSide}}(ctx context.Context, stream streamx.ClientStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) (resp {{.Resp.Type}}, err error) { + {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{NotPtr .Resp.Type}}]) (err error) { + {{- else if $bidiSide}}(ctx context.Context, stream streamx.BidiStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) (err error) { {{- end}} // TODO: Your code here... return diff --git a/tool/internal_pkg/tpl/streamx/server.go b/tool/internal_pkg/tpl/streamx/server.go index 46daeb4667..217c7b3fa1 100644 --- a/tool/internal_pkg/tpl/streamx/server.go +++ b/tool/internal_pkg/tpl/streamx/server.go @@ -24,9 +24,9 @@ type Server interface { {{- $bidiSide := and .ClientStreaming .ServerStreaming}} {{- $arg := index .Args 0}} {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}) ({{.Resp.Type}}, error) - {{- else if $clientSide}}(ctx context.Context, stream streamx.ClientStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) ({{.Resp.Type}}, error) - {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr .Resp.Type}}]) error - {{- else if $bidiSide}}(ctx context.Context, stream streamx.BidiStreamingServer[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) error + {{- else if $clientSide}}(ctx context.Context, stream streamx.ClientStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) ({{.Resp.Type}}, error) + {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{NotPtr .Resp.Type}}]) error + {{- else if $bidiSide}}(ctx context.Context, stream streamx.BidiStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) error {{- end}} {{- end}} } diff --git a/tool/internal_pkg/tpl/streamx/service.go b/tool/internal_pkg/tpl/streamx/service.go index b55bf2e80e..cbd727d5d0 100644 --- a/tool/internal_pkg/tpl/streamx/service.go +++ b/tool/internal_pkg/tpl/streamx/service.go @@ -33,7 +33,7 @@ var svcInfo = &serviceinfo.ServiceInfo{ {{- end}} "{{.RawName}}": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[{{$protocol}}.Header, {{$protocol}}.Trailer, {{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + return streamxserver.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( ctx, {{$mode}}, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) }, nil, From a99a3c536d9558c8ebb16d214f3187fab91d5d06 Mon Sep 17 00:00:00 2001 From: Joway Date: Thu, 10 Oct 2024 17:24:59 +0800 Subject: [PATCH 06/34] perf: invoker cache and recycle frame perf: mux trans pool perf: recycle frame and rm err in stream io (#1570) perf: invoker cache --- .../ttstream/client_trans_pool_mux.go | 74 ++++-- pkg/streamx/provider/ttstream/frame.go | 22 +- pkg/streamx/provider/ttstream/stream.go | 17 +- pkg/streamx/provider/ttstream/stream_io.go | 56 +++-- pkg/streamx/provider/ttstream/transport.go | 216 ++++++++---------- .../provider/ttstream/ttstream_client_test.go | 5 +- .../provider/ttstream/ttstream_server_test.go | 4 +- server/streamxserver/server_gen.go | 16 +- 8 files changed, 211 insertions(+), 199 deletions(-) diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go index f8a9330ff3..dfdf534dcf 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go @@ -19,6 +19,7 @@ package ttstream import ( "runtime" "sync" + "sync/atomic" "time" "github.com/cloudwego/netpoll" @@ -29,43 +30,76 @@ import ( var _ transPool = (*muxTransPool)(nil) +type muxTransList struct { + L sync.RWMutex + size int + cursor uint32 + transports []*transport +} + +func newMuxTransList(size int) *muxTransList { + tl := new(muxTransList) + tl.size = size + tl.transports = make([]*transport, size) + return tl +} + +func (tl *muxTransList) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (*transport, error) { + idx := atomic.AddUint32(&tl.cursor, 1) % uint32(tl.size) + tl.L.RLock() + trans := tl.transports[idx] + tl.L.RUnlock() + if trans != nil && trans.IsActive() { + return trans, nil + } + + conn, err := netpoll.DialConnection(network, addr, time.Second) + if err != nil { + return nil, err + } + trans = newTransport(clientTransport, sinfo, conn) + _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { + // peer close + _ = trans.Close() + return nil + }) + runtime.SetFinalizer(trans, func(trans *transport) { + // self close when not hold by user + _ = trans.Close() + }) + tl.L.Lock() + tl.transports[idx] = trans + tl.L.Unlock() + return trans, nil +} + func newMuxTransPool() transPool { t := new(muxTransPool) + t.poolSize = runtime.GOMAXPROCS(0) return t } type muxTransPool struct { - pool sync.Map // addr:*transport - sflight singleflight.Group + poolSize int + pool sync.Map // addr:*muxTransList + sflight singleflight.Group } func (m *muxTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (trans *transport, err error) { v, ok := m.pool.Load(addr) if ok { - return v.(*transport), nil + return v.(*muxTransList).Get(sinfo, network, addr) } + v, err, _ = m.sflight.Do(addr, func() (interface{}, error) { - conn, err := netpoll.DialConnection(network, addr, time.Second) - if err != nil { - return nil, err - } - trans = newTransport(clientTransport, sinfo, conn) - _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { - // peer close - _ = trans.Close() - return nil - }) - m.pool.Store(addr, trans) - runtime.SetFinalizer(trans, func(trans *transport) { - // self close when not hold by user - _ = trans.Close() - }) - return trans, nil + transList := newMuxTransList(m.poolSize) + m.pool.Store(addr, transList) + return transList, nil }) if err != nil { return nil, err } - return v.(*transport), nil + return v.(*muxTransList).Get(sinfo, network, addr) } func (m *muxTransPool) Put(trans *transport) { diff --git a/pkg/streamx/provider/ttstream/frame.go b/pkg/streamx/provider/ttstream/frame.go index 97a6bb8938..77b34426a3 100644 --- a/pkg/streamx/provider/ttstream/frame.go +++ b/pkg/streamx/provider/ttstream/frame.go @@ -27,7 +27,6 @@ import ( gopkgthrift "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/ttheader" - "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/streamx" ) @@ -52,7 +51,6 @@ type Frame struct { meta IntHeader typ int32 payload []byte - err error } func newFrame(sframe streamFrame, meta IntHeader, typ int32, payload []byte) (fr *Frame) { @@ -69,23 +67,11 @@ func newFrame(sframe streamFrame, meta IntHeader, typ int32, payload []byte) (fr return fr } -func newErrFrame(err error) (fr *Frame) { - v := framePool.Get() - if v == nil { - fr = new(Frame) - } else { - fr = v.(*Frame) - } - fr.err = err - return fr -} - func recycleFrame(frame *Frame) { frame.streamFrame = streamFrame{} frame.meta = nil frame.typ = 0 frame.payload = nil - frame.err = nil framePool.Put(frame) } @@ -184,15 +170,13 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro return fr, nil } -var thriftCodec = thrift.NewThriftCodec() - func EncodePayload(ctx context.Context, msg any) ([]byte, error) { - payload, err := thrift.MarshalThriftData(ctx, thriftCodec, msg) - return payload, err + payload := gopkgthrift.FastMarshal(msg.(gopkgthrift.FastCodec)) + return payload, nil } func DecodePayload(ctx context.Context, payload []byte, msg any) error { - return thrift.UnmarshalThriftData(ctx, thriftCodec, "", payload, msg) + return gopkgthrift.FastUnmarshal(payload, msg.(gopkgthrift.FastCodec)) } func EncodeException(ctx context.Context, method string, seq int32, ex tException) ([]byte, error) { diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index c6196e26a2..720e21a21d 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -75,7 +75,6 @@ type stream struct { peerEOF int32 headerSig chan int32 trailerSig chan int32 - err error StreamMeta metaHandler MetaFrameHandler @@ -166,16 +165,18 @@ func (s *stream) readTrailerFrame(fr *Frame) (err error) { return fmt.Errorf("stream read a unexcept trailer") } + var exception error // when server-side returns non-biz error, it will be wrapped as ApplicationException stored in trailer frame payload if len(fr.payload) > 0 { - _, _, ex := thrift.UnmarshalFastMsg(fr.payload, nil) - s.err = ex.(*thrift.ApplicationException) - } else { // when server-side returns biz error, payload is empty and biz error information is stored in trailer frame header + // exception is type of (*thrift.ApplicationException) + _, _, exception = thrift.UnmarshalFastMsg(fr.payload, nil) + } else { + // when server-side returns biz error, payload is empty and biz error information is stored in trailer frame header bizErr, err := transmeta.ParseBizStatusErr(fr.trailer) if err != nil { - s.err = err + exception = err } else if bizErr != nil { - s.err = bizErr + exception = bizErr } } s.trailer = fr.trailer @@ -190,8 +191,8 @@ func (s *stream) readTrailerFrame(fr *Frame) (err error) { default: } - klog.Debugf("stream[%d] recv trailer: %v, err: %v", s.sid, s.trailer, s.err) - return s.trans.streamCloseRecv(s) + klog.Debugf("stream[%d] recv trailer: %v, exception: %v", s.sid, s.trailer, exception) + return s.trans.streamCloseRecv(s, exception) } func (s *stream) writeTrailer(tl streamx.Trailer) (err error) { diff --git a/pkg/streamx/provider/ttstream/stream_io.go b/pkg/streamx/provider/ttstream/stream_io.go index 65851ca4dd..83b37264e3 100644 --- a/pkg/streamx/provider/ttstream/stream_io.go +++ b/pkg/streamx/provider/ttstream/stream_io.go @@ -26,18 +26,23 @@ import ( "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" ) +type streamIOMsg struct { + payload []byte + exception error +} + type streamIO struct { - ctx context.Context - trigger chan struct{} - stream *stream + ctx context.Context + trigger chan struct{} + stream *stream + pipe *container.Pipe[streamIOMsg] + cache [1]streamIOMsg + exception error // once has exception, the stream should not work normally again // eofFlag == 2 when both parties send trailers eofFlag int32 // eofCallback will be called when eofFlag == 2 // eofCallback will not be called if stream is not be ended in a normal way eofCallback func() - fpipe *container.Pipe[*Frame] - fcache [1]*Frame - err error } func newStreamIO(ctx context.Context, s *stream) *streamIO { @@ -45,7 +50,7 @@ func newStreamIO(ctx context.Context, s *stream) *streamIO { sio.ctx = ctx sio.trigger = make(chan struct{}) sio.stream = s - sio.fpipe = container.NewPipe[*Frame]() + sio.pipe = container.NewPipe[streamIOMsg]() return sio } @@ -53,39 +58,40 @@ func (s *streamIO) setEOFCallback(f func()) { s.eofCallback = f } -func (s *streamIO) input(ctx context.Context, f *Frame) { - err := s.fpipe.Write(ctx, f) +func (s *streamIO) input(ctx context.Context, msg streamIOMsg) { + err := s.pipe.Write(ctx, msg) if err != nil { - klog.Errorf("fpipe write failed: %v", err) + klog.Errorf("pipe write failed: %v", err) } } -func (s *streamIO) output(ctx context.Context) (f *Frame, err error) { - if s.err != nil { - return nil, s.err +func (s *streamIO) output(ctx context.Context) (msg streamIOMsg, err error) { + if s.exception != nil { + return msg, s.exception } - n, err := s.fpipe.Read(ctx, s.fcache[:]) + + n, err := s.pipe.Read(ctx, s.cache[:]) if err != nil { if errors.Is(err, container.ErrPipeEOF) { err = io.EOF } - s.err = err - return nil, s.err + s.exception = err + return msg, s.exception } if n == 0 { - s.err = io.EOF - return nil, s.err + s.exception = io.EOF + return msg, s.exception } - f = s.fcache[0] - if f.err != nil { - s.err = f.err - return nil, s.err + msg = s.cache[0] + if msg.exception != nil { + s.exception = msg.exception + return msg, s.exception } - return f, nil + return msg, nil } func (s *streamIO) closeRecv() { - s.fpipe.Close() + s.pipe.Close() if atomic.AddInt32(&s.eofFlag, 1) == 2 && s.eofCallback != nil { s.eofCallback() } @@ -98,5 +104,5 @@ func (s *streamIO) closeSend() { } func (s *streamIO) cancel() { - s.fpipe.Cancel() + s.pipe.Cancel() } diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index ab7825dfec..2944426c58 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -18,7 +18,6 @@ package ttstream import ( "context" - "encoding/binary" "errors" "fmt" "io" @@ -27,6 +26,7 @@ import ( "time" "github.com/bytedance/gopkg/lang/mcache" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/netpoll" @@ -41,7 +41,7 @@ const ( serverTransport int32 = 2 streamCacheSize = 32 - frameChanSize = 32 + frameCacheSize = 256 ) func isIgnoreError(err error) bool { @@ -54,9 +54,8 @@ type transport struct { conn netpoll.Connection streams sync.Map // key=streamID val=streamIO scache []*stream // size is streamCacheSize - spipe *container.Pipe[*stream] // in-coming stream channel - wchannel chan *Frame - closed chan struct{} + spipe *container.Pipe[*stream] // in-coming stream pipe + fpipe *container.Pipe[*Frame] // out-coming frame pipe closedFlag int32 streamingFlag int32 // flag == 0 means there is no active stream on transport } @@ -66,14 +65,13 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne // TODO: let it configurable _ = conn.SetReadTimeout(time.Minute * 10) t := &transport{ - kind: kind, - sinfo: sinfo, - conn: conn, - streams: sync.Map{}, - spipe: container.NewPipe[*stream](), - scache: make([]*stream, 0, streamCacheSize), - wchannel: make(chan *Frame, frameChanSize), - closed: make(chan struct{}), + kind: kind, + sinfo: sinfo, + conn: conn, + streams: sync.Map{}, + spipe: container.NewPipe[*stream](), + scache: make([]*stream, 0, streamCacheSize), + fpipe: container.NewPipe[*Frame](), } go func() { err := t.loopRead() @@ -105,9 +103,9 @@ func (t *transport) Close() (err error) { if !atomic.CompareAndSwapInt32(&t.closedFlag, 0, 1) { return nil } - close(t.closed) klog.Debugf("transport[%s] is closing", t.conn.LocalAddr()) t.spipe.Close() + t.fpipe.Close() t.streams.Range(func(key, value any) bool { sio := value.(*streamIO) sio.stream.close() @@ -136,123 +134,101 @@ func (t *transport) loadStreamIO(sid int32) (sio *streamIO, ok bool) { return sio, true } -func (t *transport) loopRead() error { - addr := t.conn.RemoteAddr().String() - if t.kind == clientTransport { - addr = t.conn.LocalAddr().String() +func (t *transport) readFrame(reader bufiox.Reader) error { + fr, err := DecodeFrame(context.Background(), reader) + if err != nil { + return err } - for { - // decode frame - sizeBuf, err := t.conn.Reader().Peek(4) - if err != nil { - return err - } - size := binary.BigEndian.Uint32(sizeBuf) - slice, err := t.conn.Reader().Slice(int(size + 4)) - if err != nil { - return err - } - reader := newReaderBuffer(slice) - fr, err := DecodeFrame(context.Background(), reader) - if err != nil { - return err - } - klog.Debugf("transport[%d-%s] DecodeFrame: fr=%v", t.kind, addr, fr) + defer recycleFrame(fr) + klog.Debugf("transport[%d] DecodeFrame: fr=%v", t.kind, fr) - switch fr.typ { - case metaFrameType: - sio, ok := t.loadStreamIO(fr.sid) - if !ok { - klog.Errorf("transport[%d-%s] read a unknown stream meta: sid=%d", t.kind, addr, fr.sid) - continue - } + switch fr.typ { + case metaFrameType: + sio, ok := t.loadStreamIO(fr.sid) + if ok { err = sio.stream.readMetaFrame(fr.meta, fr.header, fr.payload) - if err != nil { - return err - } - case headerFrameType: - switch t.kind { - case serverTransport: - // Header Frame: server recv a new stream - smode := t.sinfo.MethodInfo(fr.method).StreamingMode() - s := newStream(t, smode, fr.streamFrame) - t.storeStreamIO(context.Background(), s) - t.spipe.Write(context.Background(), s) - case clientTransport: - // Header Frame: client recv header - sio, ok := t.loadStreamIO(fr.sid) - if !ok { - klog.Errorf("transport[%d-%s] read a unknown stream header: sid=%d header=%v", - t.kind, addr, fr.sid, fr.header) - continue - } - err = sio.stream.readHeader(fr.header) - if err != nil { - return err - } - } - case dataFrameType: - // Data Frame: decode and distribute data - sio, ok := t.loadStreamIO(fr.sid) - if !ok { - klog.Errorf("transport[%d-%s] read a unknown stream data: sid=%d", t.kind, addr, fr.sid) - continue - } - sio.input(context.Background(), fr) - case trailerFrameType: - // Trailer Frame: recv trailer, Close read direction + } else { + klog.Errorf("transport[%d] read a unknown stream meta: sid=%d", t.kind, fr.sid) + } + case headerFrameType: + switch t.kind { + case serverTransport: + // Header Frame: server recv a new stream + smode := t.sinfo.MethodInfo(fr.method).StreamingMode() + s := newStream(t, smode, fr.streamFrame) + t.storeStreamIO(context.Background(), s) + err = t.spipe.Write(context.Background(), s) + case clientTransport: + // Header Frame: client recv header sio, ok := t.loadStreamIO(fr.sid) - if !ok { - // client recv an unknown trailer is in exception, - // because the client stream may already be GCed, - // but the connection is still active so peer server can send a trailer - klog.Errorf("transport[%d-%s] read a unknown stream trailer: sid=%d trailer=%v", - t.kind, addr, fr.sid, fr.trailer) - continue - } - if err = sio.stream.readTrailerFrame(fr); err != nil { - return err + if ok { + err = sio.stream.readHeader(fr.header) + } else { + klog.Errorf("transport[%d] read a unknown stream header: sid=%d header=%v", + t.kind, fr.sid, fr.header) } } + case dataFrameType: + // Data Frame: decode and distribute data + sio, ok := t.loadStreamIO(fr.sid) + if ok { + sio.input(context.Background(), streamIOMsg{payload: fr.payload}) + } else { + klog.Errorf("transport[%d] read a unknown stream data: sid=%d", t.kind, fr.sid) + } + case trailerFrameType: + // Trailer Frame: recv trailer, Close read direction + sio, ok := t.loadStreamIO(fr.sid) + if ok { + err = sio.stream.readTrailerFrame(fr) + } else { + // client recv an unknown trailer is in exception, + // because the client stream may already be GCed, + // but the connection is still active so peer server can send a trailer + klog.Errorf("transport[%d] read a unknown stream trailer: sid=%d trailer=%v", + t.kind, fr.sid, fr.trailer) + } } + return err } -func (t *transport) loopWrite() (err error) { +func (t *transport) loopRead() error { + reader := newReaderBuffer(t.conn.Reader()) + for { + err := t.readFrame(reader) + // read frame return an un-recovered error, so we should close the transport + if err != nil { + return err + } + } +} + +func (t *transport) loopWrite() error { defer func() { // loop write should help to close connection _ = t.conn.Close() }() writer := newWriterBuffer(t.conn.Writer()) - delay := 0 + fcache := make([]*Frame, frameCacheSize) // Important note: // loopWrite may cannot find stream by sid since it may send trailer and delete sid from streams for { - select { - case <-t.closed: - return nil - case fr, ok := <-t.wchannel: - if !ok { - // closed - return nil - } - select { - case <-t.closed: - // double check closed - return nil - default: - } - + n, err := t.fpipe.Read(context.Background(), fcache) + if err != nil { + return err + } + if n == 0 { + return io.EOF + } + for i := 0; i < n; i++ { + fr := fcache[i] if err = EncodeFrame(context.Background(), writer, fr); err != nil { return err } - if delay >= 8 || len(t.wchannel) == 0 { - delay = 0 - if err = t.conn.Writer().Flush(); err != nil { - return err - } - } else { - delay++ - } + recycleFrame(fr) + } + if err = t.conn.Writer().Flush(); err != nil { + return err } } } @@ -260,8 +236,7 @@ func (t *transport) loopWrite() (err error) { // writeFrame is concurrent safe func (t *transport) writeFrame(sframe streamFrame, meta IntHeader, ftype int32, payload []byte) (err error) { frame := newFrame(sframe, meta, ftype, payload) - t.wchannel <- frame - return nil + return t.fpipe.Write(context.Background(), frame) } func (t *transport) streamSend(ctx context.Context, sid int32, method string, wheader streamx.Header, res any) (err error) { @@ -315,24 +290,23 @@ func (t *transport) streamRecv(ctx context.Context, sid int32, data any) (err er if !ok { return io.EOF } - f, err := sio.output(ctx) + msg, err := sio.output(ctx) if err != nil { return err } - err = DecodePayload(context.Background(), f.payload, data.(thrift.FastCodec)) + err = DecodePayload(context.Background(), msg.payload, data.(thrift.FastCodec)) // payload will not be access after decode - mcache.Free(f.payload) - recycleFrame(f) - return nil + mcache.Free(msg.payload) + return err } -func (t *transport) streamCloseRecv(s *stream) error { +func (t *transport) streamCloseRecv(s *stream, exception error) error { sio, ok := t.loadStreamIO(s.sid) if !ok { return fmt.Errorf("stream not found in stream map: sid=%d", s.sid) } - if s.err != nil { - sio.input(context.Background(), newErrFrame(s.err)) + if exception != nil { + sio.input(context.Background(), streamIOMsg{exception: exception}) } sio.closeRecv() return nil diff --git a/pkg/streamx/provider/ttstream/ttstream_client_test.go b/pkg/streamx/provider/ttstream/ttstream_client_test.go index 10c1e96ede..df45e69d3b 100644 --- a/pkg/streamx/provider/ttstream/ttstream_client_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_client_test.go @@ -271,7 +271,7 @@ func TestTTHeaderStreaming(t *testing.T) { t.Logf("Client ClientStream CloseAndRecv: %v", res) atomic.AddInt32(&serverStreamCount, -1) waitServerStreamDone() - test.Assert(t, serverRecvCount == int32(round), serverRecvCount) + test.DeepEqual(t, serverRecvCount, int32(round)) test.Assert(t, serverSendCount == 1, serverSendCount) testHeaderAndTrailer(t, cs) cs = nil @@ -568,4 +568,7 @@ func BenchmarkTTHeaderStreaming(b *testing.B) { _ = res } err = bs.CloseSend(ctx) + if err != nil { + b.Fatal(err) + } } diff --git a/pkg/streamx/provider/ttstream/ttstream_server_test.go b/pkg/streamx/provider/ttstream/ttstream_server_test.go index 3f21f24d9c..3eecb2b665 100644 --- a/pkg/streamx/provider/ttstream/ttstream_server_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_server_test.go @@ -159,7 +159,7 @@ func (si *streamingService) UnaryWithErr(ctx context.Context, req *Request) (*Re func (si *streamingService) ClientStreamWithErr(ctx context.Context, stream ClientStreamingServer[Request, Response]) (res *Response, err error) { req, err := stream.Recv(ctx) if err != nil { - klog.Errorf("Server ClientStreamWithErr Recv failed, err={%v}", err) + klog.Errorf("Server ClientStreamWithErr Recv failed, exception={%v}", err) return nil, err } err = buildErr(req) @@ -176,7 +176,7 @@ func (si *streamingService) ServerStreamWithErr(ctx context.Context, req *Reques func (si *streamingService) BidiStreamWithErr(ctx context.Context, stream BidiStreamingServer[Request, Response]) error { req, err := stream.Recv(ctx) if err != nil { - klog.Errorf("Server BidiStreamWithErr Recv failed, err={%v}", err) + klog.Errorf("Server BidiStreamWithErr Recv failed, exception={%v}", err) return err } err = buildErr(req) diff --git a/server/streamxserver/server_gen.go b/server/streamxserver/server_gen.go index 9b46f928a1..f9ec27aad2 100644 --- a/server/streamxserver/server_gen.go +++ b/server/streamxserver/server_gen.go @@ -4,11 +4,14 @@ import ( "context" "errors" "reflect" + "sync" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" ) +var invokerCache sync.Map + func InvokeStream[Req, Res any]( ctx context.Context, smode serviceinfo.StreamingMode, handler any, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { @@ -36,9 +39,16 @@ func InvokeStream[Req, Res any]( } // handler call - // TODO: cache handler - rhandler := reflect.ValueOf(shandler.Handler) - mhandler := rhandler.MethodByName(sArgs.Stream().Method()) + cacheKey := reflect.TypeOf(shandler.Handler).String() + sArgs.Stream().Method() + var mhandler reflect.Value + if v, ok := invokerCache.Load(cacheKey); ok { + mhandler = v.(reflect.Value) + } else { + rhandler := reflect.ValueOf(shandler.Handler) + mhandler = rhandler.MethodByName(sArgs.Stream().Method()) + invokerCache.Store(cacheKey, mhandler) + } + streamInvoke := func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { switch smode { case serviceinfo.StreamingUnary: From ed5b1d4b2a9e420ade84234f5b1d90feb1ed59f9 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Fri, 11 Oct 2024 18:29:55 +0800 Subject: [PATCH 07/34] feat: stream event metrics --- client/stream.go | 5 ++++ go.sum | 19 ++------------- internal/server/option.go | 5 ---- internal/stream/stream_option.go | 2 +- pkg/remote/trans/streamx/server_handler.go | 7 ++++-- pkg/streamx/client_options.go | 9 +++++++- pkg/streamx/stream_middleware_internal.go | 27 ++++++++++++++++++++++ server/server.go | 12 ++++++++++ 8 files changed, 60 insertions(+), 26 deletions(-) create mode 100644 pkg/streamx/stream_middleware_internal.go diff --git a/client/stream.go b/client/stream.go index 0e94ea13f8..073dbc4f73 100644 --- a/client/stream.go +++ b/client/stream.go @@ -82,6 +82,11 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) { // streamx version streaming mw kc.sxStreamMW = streamx.StreamMiddlewareChain(kc.opt.StreamXOptions.StreamMWs...) + eventHandler := kc.opt.TracerCtl.GetStreamEventHandler() + if eventHandler != nil { + kc.opt.StreamXOptions.StreamRecvMWs = append(kc.opt.StreamXOptions.StreamRecvMWs, streamx.NewStreamRecvStatMiddleware(eventHandler)) + kc.opt.StreamXOptions.StreamSendMWs = append(kc.opt.StreamXOptions.StreamSendMWs, streamx.NewStreamSendStatMiddleware(eventHandler)) + } kc.sxStreamRecvMW = streamx.StreamRecvMiddlewareChain(kc.opt.StreamXOptions.StreamRecvMWs...) kc.sxStreamSendMW = streamx.StreamSendMiddlewareChain(kc.opt.StreamXOptions.StreamSendMWs...) diff --git a/go.sum b/go.sum index 806d9d39d1..cf1260511b 100644 --- a/go.sum +++ b/go.sum @@ -22,29 +22,14 @@ github.com/cloudwego/fastpb v0.0.5 h1:vYnBPsfbAtU5TVz5+f9UTlmSCixG9F9vRwaqE0mZPZ github.com/cloudwego/fastpb v0.0.5/go.mod h1:Bho7aAKBUtT9RPD2cNVkTdx4yQumfSv3If7wYnm1izk= github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= -github.com/cloudwego/gopkg v0.1.2-0.20240919030844-cb7123236682 h1:hj/AhlEngERp5Tjt864veEvyK6RglXKcXpxkIOSRfug= -github.com/cloudwego/gopkg v0.1.2-0.20240919030844-cb7123236682/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= +github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4 h1:SHw9GUBBcAnLWeK2MtPH7O6YQG9Q2ZZ8koD/4alpLvE= +github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= -<<<<<<< HEAD github.com/cloudwego/localsession v0.1.1 h1:tbK7laDVrYfFDXoBXo4uCGMAxU4qmz2dDm8d4BGBnDo= github.com/cloudwego/localsession v0.1.1/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiXcs7Z+KBHP72Wv8= github.com/cloudwego/netpoll v0.6.4 h1:z/dA4sOTUQof6zZIO4QNnLBXsDFFFEos9OOGloR6kno= github.com/cloudwego/netpoll v0.6.4/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= -======= -github.com/cloudwego/localsession v0.0.2 h1:N9/IDtCPj1fCL9bCTP+DbXx3f40YjVYWcwkJG0YhQkY= -github.com/cloudwego/localsession v0.0.2/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiXcs7Z+KBHP72Wv8= -<<<<<<< HEAD -github.com/cloudwego/netpoll v0.6.5-0.20240905095957-e6ec47be2fe0 h1:2aoCxK8fee7LhwWveg3ORVEDBoMtmTY2NuSAtNGpnFI= -github.com/cloudwego/netpoll v0.6.5-0.20240905095957-e6ec47be2fe0/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= -<<<<<<< HEAD ->>>>>>> 9e44721 (feat: support multi service (#1538)) -======= -======= ->>>>>>> 2b0e374 (perf: optimise pipe and queue locker) -github.com/cloudwego/netpoll v0.6.5-0.20240911073319-2ec9568b10cf h1:c/K4XrkloCgZp+En3LjbXtqfr0KQwC85utUvdDm76V4= -github.com/cloudwego/netpoll v0.6.5-0.20240911073319-2ec9568b10cf/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= ->>>>>>> 968bdfc (chore: fix unit test (#1545)) github.com/cloudwego/runtimex v0.1.0 h1:HG+WxWoj5/CDChDZ7D99ROwvSMkuNXAqt6hnhTTZDiI= github.com/cloudwego/runtimex v0.1.0/go.mod h1:23vL/HGV0W8nSCHbe084AgEBdDV4rvXenEUMnUNvUd8= github.com/cloudwego/thriftgo v0.3.18 h1:gnr1vz7G3RbwwCK9AMKHZf63VYGa7ene6WbI9VrBJSw= diff --git a/internal/server/option.go b/internal/server/option.go index 82138d139c..ddf0ab82db 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -41,7 +41,6 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" - "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/transmeta" "github.com/cloudwego/kitex/pkg/utils" "github.com/cloudwego/localsession/backup" @@ -80,10 +79,6 @@ type Options struct { Limit Limit MWBs []endpoint.MiddlewareBuilder - // streamx - SMWs []streamx.StreamMiddleware - SRecvMWs []streamx.StreamRecvMiddleware - SSendMWs []streamx.StreamSendMiddleware Bus event.Bus Events event.Queue diff --git a/internal/stream/stream_option.go b/internal/stream/stream_option.go index 8bdf51acd3..a8be1055fc 100644 --- a/internal/stream/stream_option.go +++ b/internal/stream/stream_option.go @@ -24,7 +24,7 @@ import ( ) // StreamEventHandler is used to handle stream events -type StreamEventHandler func(ctx context.Context, evt stats.Event, err error) +type StreamEventHandler = func(ctx context.Context, evt stats.Event, err error) type StreamingConfig struct { RecvMiddlewareBuilders []endpoint.RecvMiddlewareBuilder diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index 452f40ba47..d6be4021f4 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -23,9 +23,10 @@ import ( "log" "net" "runtime/debug" + "time" + "github.com/cloudwego/kitex/internal/wpool" "github.com/cloudwego/kitex/pkg/endpoint" - "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -42,6 +43,8 @@ import ( 其他接口实际上最终是用来去组装了 transpipeline .... */ +var streamWorkerPool = wpool.New(128, time.Minute) + type svrTransHandlerFactory struct { provider streamx.ServerProvider } @@ -100,7 +103,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { return nerr } // stream level goroutine - gofunc.GoFunc(ctx, func() { + streamWorkerPool.GoCtx(ctx, func() { err := t.OnStream(nctx, conn, ss) if err != nil && !errors.Is(err, io.EOF) { klog.CtxErrorf(ctx, "KITEX: stream ReadStream failed: err=%v", nerr) diff --git a/pkg/streamx/client_options.go b/pkg/streamx/client_options.go index 333fa46eb2..9beb142357 100644 --- a/pkg/streamx/client_options.go +++ b/pkg/streamx/client_options.go @@ -1,6 +1,13 @@ package streamx -import "time" +import ( + "context" + "time" + + "github.com/cloudwego/kitex/pkg/stats" +) + +type EventHandler func(ctx context.Context, evt stats.Event, err error) type ClientOptions struct { RecvTimeout time.Duration diff --git a/pkg/streamx/stream_middleware_internal.go b/pkg/streamx/stream_middleware_internal.go new file mode 100644 index 0000000000..ab2d9f35f0 --- /dev/null +++ b/pkg/streamx/stream_middleware_internal.go @@ -0,0 +1,27 @@ +package streamx + +import ( + "context" + + "github.com/cloudwego/kitex/pkg/stats" +) + +func NewStreamRecvStatMiddleware(ehandler EventHandler) StreamRecvMiddleware { + return func(next StreamRecvEndpoint) StreamRecvEndpoint { + return func(ctx context.Context, stream Stream, res any) (err error) { + err = next(ctx, stream, res) + ehandler(ctx, stats.StreamRecv, err) + return err + } + } +} + +func NewStreamSendStatMiddleware(ehandler EventHandler) StreamSendMiddleware { + return func(next StreamSendEndpoint) StreamSendEndpoint { + return func(ctx context.Context, stream Stream, res any) (err error) { + err = next(ctx, stream, res) + ehandler(ctx, stats.StreamSend, err) + return err + } + } +} diff --git a/server/server.go b/server/server.go index e5cd271cb4..0e939b2d68 100644 --- a/server/server.go +++ b/server/server.go @@ -216,6 +216,18 @@ func (s *server) RegisterService(svcInfo *serviceinfo.ServiceInfo, handler inter } registerOpts := internal_server.NewRegisterOptions(opts) + // add trace middlewares + ehandler := s.opt.TracerCtl.GetStreamEventHandler() + if ehandler != nil { + registerOpts.StreamRecvMiddlewares = append( + registerOpts.StreamRecvMiddlewares, streamx.NewStreamRecvStatMiddleware(ehandler), + ) + registerOpts.StreamSendMiddlewares = append( + registerOpts.StreamSendMiddlewares, streamx.NewStreamSendStatMiddleware(ehandler), + ) + } + + // register service if err := s.svcs.addService(svcInfo, handler, registerOpts); err != nil { panic(err.Error()) } From ceb8a9edb227334bb01f3e55054c7ef0c1b2105f Mon Sep 17 00:00:00 2001 From: Joway Date: Wed, 16 Oct 2024 10:38:19 +0800 Subject: [PATCH 08/34] fix: stream server close sio pipe when peer close (#1578) --- client/client_streamx.go | 1 + pkg/remote/trans/streamx/server_handler.go | 2 +- .../provider/ttstream/client_provier.go | 22 +- .../provider/ttstream/container/pipe.go | 19 +- pkg/streamx/provider/ttstream/stream.go | 4 - pkg/streamx/provider/ttstream/stream_io.go | 6 + pkg/streamx/provider/ttstream/transport.go | 11 +- .../provider/ttstream/transport_test.go | 203 ------------------ .../provider/ttstream/ttstream_client_test.go | 58 ++++- .../provider/ttstream/ttstream_server_test.go | 6 +- 10 files changed, 95 insertions(+), 237 deletions(-) delete mode 100644 pkg/streamx/provider/ttstream/transport_test.go diff --git a/client/client_streamx.go b/client/client_streamx.go index 64e2d8a672..60d1e767bf 100644 --- a/client/client_streamx.go +++ b/client/client_streamx.go @@ -43,6 +43,7 @@ func (kc *kClient) NewStream(ctx context.Context, method string, req any, callOp // it's an ugly trick but if we don't want to refactor too much, // this is the only way to compatible with current endpoint design err = kc.sEps(ctx, req, streamArgs) + kc.opt.TracerCtl.DoFinish(ctx, ri, err) if err != nil { return nil, err } diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index d6be4021f4..84132bdcd7 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -43,7 +43,7 @@ import ( 其他接口实际上最终是用来去组装了 transpipeline .... */ -var streamWorkerPool = wpool.New(128, time.Minute) +var streamWorkerPool = wpool.New(128, time.Second) type svrTransHandlerFactory struct { provider streamx.ServerProvider diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go index 0762bf30f6..5c82721a5b 100644 --- a/pkg/streamx/provider/ttstream/client_provier.go +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -90,7 +90,7 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callO // if ctx from server side, we should cancel the stream when server handler already returned // TODO: this canceling transmit should be configurable ktx.RegisterCancelCallback(ctx, func() { - sio.stream.cancel() + sio.cancel() }) cs := newClientStream(sio.stream) @@ -98,12 +98,8 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callO var ended uint32 sio.setEOFCallback(func() { // if stream is ended by both parties, put the transport back to pool - sio.stream.close() if atomic.AddUint32(&ended, 1) == 2 { - if trans.IsActive() { - c.transPool.Put(trans) - } - err = trans.streamDelete(sio.stream.sid) + _ = c.streamFinalize(sio, trans) } }) runtime.SetFinalizer(cs, func(cstream *clientStream) { @@ -112,11 +108,17 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callO _ = cstream.CloseSend(ctx) // only delete stream when clientStream be finalized if atomic.AddUint32(&ended, 1) == 2 { - if trans.IsActive() { - c.transPool.Put(trans) - } - err = trans.streamDelete(sio.stream.sid) + _ = c.streamFinalize(sio, trans) } }) return cs, err } + +func (c clientProvider) streamFinalize(sio *streamIO, trans *transport) error { + sio.close() + err := trans.streamDelete(sio.stream.sid) + if trans.IsActive() { + c.transPool.Put(trans) + } + return err +} diff --git a/pkg/streamx/provider/ttstream/container/pipe.go b/pkg/streamx/provider/ttstream/container/pipe.go index 15c92590ed..6847002003 100644 --- a/pkg/streamx/provider/ttstream/container/pipe.go +++ b/pkg/streamx/provider/ttstream/container/pipe.go @@ -49,7 +49,7 @@ type Pipe[Item any] struct { func NewPipe[Item any]() *Pipe[Item] { p := new(Pipe[Item]) p.queue = NewQueue[Item]() - p.trigger = make(chan struct{}, 1) + p.trigger = make(chan struct{}) return p } @@ -86,6 +86,7 @@ READ: if err != nil { return 0, err } + return 0, fmt.Errorf("unknown err") } goto READ } @@ -112,11 +113,19 @@ func (p *Pipe[Item]) Write(ctx context.Context, items ...Item) error { } func (p *Pipe[Item]) Close() { - atomic.StoreInt32(&p.state, pipeStateClosed) - close(p.trigger) + select { + case <-p.trigger: + default: + atomic.StoreInt32(&p.state, pipeStateClosed) + close(p.trigger) + } } func (p *Pipe[Item]) Cancel() { - atomic.StoreInt32(&p.state, pipeStateCanceled) - close(p.trigger) + select { + case <-p.trigger: + default: + atomic.StoreInt32(&p.state, pipeStateCanceled) + close(p.trigger) + } } diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index 720e21a21d..f426ed3bf7 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -259,10 +259,6 @@ func (s *stream) RecvMsg(ctx context.Context, req any) error { return s.trans.streamRecv(ctx, s.sid, req) } -func (s *stream) cancel() { - _ = s.trans.streamCancel(s) -} - func newClientStream(s *stream) *clientStream { cs := &clientStream{stream: s} return cs diff --git a/pkg/streamx/provider/ttstream/stream_io.go b/pkg/streamx/provider/ttstream/stream_io.go index 83b37264e3..8eabdd5076 100644 --- a/pkg/streamx/provider/ttstream/stream_io.go +++ b/pkg/streamx/provider/ttstream/stream_io.go @@ -103,6 +103,12 @@ func (s *streamIO) closeSend() { } } +func (s *streamIO) close() { + s.stream.close() + s.pipe.Close() +} + func (s *streamIO) cancel() { s.pipe.Cancel() + s.stream.close() } diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index 2944426c58..e7686bfdf3 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -108,7 +108,7 @@ func (t *transport) Close() (err error) { t.fpipe.Close() t.streams.Range(func(key, value any) bool { sio := value.(*streamIO) - sio.stream.close() + sio.close() _ = t.streamDelete(sio.stream.sid) return true }) @@ -312,15 +312,6 @@ func (t *transport) streamCloseRecv(s *stream, exception error) error { return nil } -func (t *transport) streamCancel(s *stream) (err error) { - sio, ok := t.loadStreamIO(s.sid) - if !ok { - return fmt.Errorf("stream not found in stream map: sid=%d", s.sid) - } - sio.cancel() - return nil -} - func (t *transport) streamDelete(sid int32) (err error) { // remove stream from transport _, ok := t.streams.LoadAndDelete(sid) diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go deleted file mode 100644 index 729d49b121..0000000000 --- a/pkg/streamx/provider/ttstream/transport_test.go +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ttstream - -import ( - "context" - "errors" - "io" - "net" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/cloudwego/netpoll" - - "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/kitex/pkg/streamx" -) - -func TestTransport(t *testing.T) { - method := "BidiStream" - sinfo := &serviceinfo.ServiceInfo{ - ServiceName: "a.b.c", - Methods: map[string]serviceinfo.MethodInfo{ - method: serviceinfo.NewMethodInfo( - func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return nil - }, - nil, - nil, - false, - serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), - ), - }, - Extra: map[string]interface{}{"streaming": true}, - } - - addr := test.GetLocalAddress() - ln, err := net.Listen("tcp", addr) - test.Assert(t, err == nil, err) - defer ln.Close() - - var connDone int32 - var streamDone int32 - svr, err := netpoll.NewEventLoop(nil, - netpoll.WithOnConnect(func(ctx context.Context, connection netpoll.Connection) context.Context { - t.Logf("OnConnect started") - defer t.Logf("OnConnect finished") - trans := newTransport(serverTransport, sinfo, connection) - t.Logf("OnRead started") - defer t.Log("OnRead finished") - - go func() { - for { - s, err := trans.readStream(ctx) - t.Logf("OnRead read stream: %v, %v", s, err) - if err != nil { - if err == io.EOF { - return - } - t.Error(err) - } - ss := newServerStream(s) - go func(st streamx.ServerStream) { - defer func() { - // set trailer - err := st.(ServerStreamMeta).SetTrailer(streamx.Trailer{"key": "val"}) - test.Assert(t, err == nil, err) - - // send trailer - err = ss.(*serverStream).sendTrailer(ctx, nil) - test.Assert(t, err == nil, err) - atomic.AddInt32(&streamDone, -1) - }() - - // send header - err := st.(ServerStreamMeta).SendHeader(streamx.Header{"key": "val"}) - test.Assert(t, err == nil, err) - - // send data - for { - req := new(TestRequest) - err := st.RecvMsg(ctx, req) - if errors.Is(err, io.EOF) { - t.Logf("server stream closeRecv") - return - } - test.Assert(t, err == nil, err) - t.Logf("server recv msg: %v", req) - - res := req - err = st.SendMsg(ctx, res) - if errors.Is(err, io.EOF) { - return - } - test.Assert(t, err == nil, err) - t.Logf("server send msg: %v", res) - } - }(ss) - } - }() - - return context.WithValue(ctx, "trans", trans) - }), netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) { - t.Logf("OnDisconnect started") - defer t.Logf("OnDisconnect finished") - - atomic.AddInt32(&connDone, -1) - })) - go func() { - err = svr.Serve(ln) - test.Assert(t, err == nil, err) - }() - defer svr.Shutdown(context.Background()) - test.WaitServerStart(addr) - - // Client - ctx := context.Background() - atomic.AddInt32(&connDone, 1) - conn, err := netpoll.DialConnection("tcp", addr, time.Second) - test.Assert(t, err == nil, err) - trans := newTransport(clientTransport, sinfo, conn) - - var wg sync.WaitGroup - for sid := 1; sid <= 1; sid++ { - wg.Add(1) - atomic.AddInt32(&streamDone, 1) - go func(sid int) { - defer wg.Done() - - // send header - sio, err := trans.newStreamIO(ctx, method, IntHeader{}, map[string]string{}) - test.Assert(t, err == nil, err) - - cs := newClientStream(sio.stream) - t.Logf("client stream[%d] created", sid) - - // recv header - hd, err := cs.Header() - test.Assert(t, err == nil, err) - test.Assert(t, hd["key"] == "val", hd) - t.Logf("client stream[%d] recv header=%v", sid, hd) - - // send and recv data - for i := 0; i < 3; i++ { - req := new(TestRequest) - req.A = 12345 - req.B = "hello" - res := new(TestResponse) - err = cs.SendMsg(ctx, req) - t.Logf("client stream[%d] send msg: %v", sid, req) - test.Assert(t, err == nil, err) - err = cs.RecvMsg(ctx, res) - t.Logf("client stream[%d] recv msg: %v", sid, res) - test.Assert(t, err == nil, err) - test.Assert(t, req.A == res.A, res) - test.Assert(t, req.B == res.B, res) - } - - // send trailer(trailer is stored in ctx) - err = cs.CloseSend(ctx) - test.Assert(t, err == nil, err) - t.Logf("client stream[%d] Close send", sid) - - // recv trailer - tl, err := cs.Trailer() - test.Assert(t, err == nil, err) - test.Assert(t, tl["key"] == "val", tl) - t.Logf("client stream[%d] recv trailer=%v", sid, tl) - }(sid) - } - wg.Wait() - for atomic.LoadInt32(&streamDone) != 0 { - t.Logf("wait all streams closed") - time.Sleep(time.Millisecond * 10) - } - - // Close conn - err = trans.Close() - test.Assert(t, err == nil, err) - err = ln.Close() - test.Assert(t, err == nil, err) - for atomic.LoadInt32(&connDone) != 0 { - time.Sleep(time.Millisecond * 10) - t.Logf("wait all connections closed") - } -} diff --git a/pkg/streamx/provider/ttstream/ttstream_client_test.go b/pkg/streamx/provider/ttstream/ttstream_client_test.go index df45e69d3b..6f918c1d56 100644 --- a/pkg/streamx/provider/ttstream/ttstream_client_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_client_test.go @@ -56,10 +56,14 @@ func testHeaderAndTrailer(t *testing.T, stream streamx.ClientStreamMetadata) { test.Assert(t, tl[trailerKey] == trailerVal, tl) } -func TestTTHeaderStreaming(t *testing.T) { +func TestMain(m *testing.M) { go func() { log.Println(http.ListenAndServe("localhost:6060", nil)) }() + m.Run() +} + +func TestTTHeaderStreaming(t *testing.T) { var addr = test.GetLocalAddress() ln, err := netpoll.CreateListener("tcp", addr) test.Assert(t, err == nil, err) @@ -526,6 +530,58 @@ func TestTTHeaderStreamingRecvTimeout(t *testing.T) { test.Assert(t, err == nil, err) } +func TestTTHeaderStreamingServerGoroutines(t *testing.T) { + var addr = test.GetLocalAddress() + ln, _ := netpoll.CreateListener("tcp", addr) + defer ln.Close() + + // create server + svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) + // register streamingService as ttstreaam provider + sp, _ := ttstream.NewServerProvider(streamingServiceInfo) + _ = svr.RegisterService( + streamingServiceInfo, + new(streamingService), + streamxserver.WithProvider(sp), + ) + go func() { + _ = svr.Run() + }() + defer svr.Stop() + test.WaitServerStart(addr) + + cp, _ := ttstream.NewClientProvider( + streamingServiceInfo, + ttstream.WithClientLongConnPool(ttstream.LongConnConfig{MaxIdleTimeout: time.Second}), + ) + streamClient, _ := NewStreamingClient( + "kitex.service.streaming", + streamxclient.WithHostPorts(addr), + streamxclient.WithProvider(cp), + ) + + oldNGs := runtime.NumGoroutine() + streams := 100 + streamList := make([]streamx.ServerStream, streams) + for i := 0; i < streams; i++ { + ctx := context.Background() + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + streamList[i] = bs + } + ngs := runtime.NumGoroutine() + test.Assert(t, ngs > streams, ngs) + for i := 0; i < streams; i++ { + streamList[i] = nil + } + streamList = nil + for ngs-oldNGs > 10 { + runtime.GC() + ngs = runtime.NumGoroutine() + time.Sleep(time.Millisecond * 100) + } +} + func BenchmarkTTHeaderStreaming(b *testing.B) { klog.SetLevel(klog.LevelWarn) var addr = test.GetLocalAddress() diff --git a/pkg/streamx/provider/ttstream/ttstream_server_test.go b/pkg/streamx/provider/ttstream/ttstream_server_test.go index 3eecb2b665..a0366af090 100644 --- a/pkg/streamx/provider/ttstream/ttstream_server_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_server_test.go @@ -156,7 +156,7 @@ func (si *streamingService) UnaryWithErr(ctx context.Context, req *Request) (*Re return nil, err } -func (si *streamingService) ClientStreamWithErr(ctx context.Context, stream ClientStreamingServer[Request, Response]) (res *Response, err error) { +func (si *streamingService) ClientStreamWithErr(ctx context.Context, stream streamx.ClientStreamingServer[Request, Response]) (res *Response, err error) { req, err := stream.Recv(ctx) if err != nil { klog.Errorf("Server ClientStreamWithErr Recv failed, exception={%v}", err) @@ -167,13 +167,13 @@ func (si *streamingService) ClientStreamWithErr(ctx context.Context, stream Clie return nil, err } -func (si *streamingService) ServerStreamWithErr(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error { +func (si *streamingService) ServerStreamWithErr(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error { err := buildErr(req) klog.Infof("Server ServerStreamWithErr: req={%v} err={%v}", req, err) return err } -func (si *streamingService) BidiStreamWithErr(ctx context.Context, stream BidiStreamingServer[Request, Response]) error { +func (si *streamingService) BidiStreamWithErr(ctx context.Context, stream streamx.BidiStreamingServer[Request, Response]) error { req, err := stream.Recv(ctx) if err != nil { klog.Errorf("Server BidiStreamWithErr Recv failed, exception={%v}", err) From 9d05639b9c70a47594a76ac0fef80c63220aef02 Mon Sep 17 00:00:00 2001 From: Joway Date: Thu, 17 Oct 2024 17:42:27 +0800 Subject: [PATCH 09/34] chore(ttstream): fix ci (#1579) --- .github/workflows/tests.yml | 2 +- client/client.go | 1 - client/client_streamx.go | 16 + client/stream_test.go | 2 +- client/streamxclient/client.go | 16 + client/streamxclient/client_gen.go | 16 + client/streamxclient/client_option.go | 16 + .../streamxcallopt/call_option.go | 16 + internal/server/option.go | 3 +- pkg/remote/trans/streamx/server_handler.go | 13 +- pkg/rpcinfo/mocks_test.go | 4 + pkg/streamx/client_options.go | 16 + pkg/streamx/header_trailer.go | 22 +- .../provider/jsonrpc/jsonrpc_gen_test.go | 177 ----- .../provider/jsonrpc/jsonrpc_impl_test.go | 84 --- pkg/streamx/provider/jsonrpc/jsonrpc_test.go | 231 ------- pkg/streamx/provider/jsonrpc/protocol.go | 3 +- pkg/streamx/provider/jsonrpc/transport.go | 3 +- .../provider/jsonrpc/transport_test.go | 2 +- .../provider/ttstream/client_provier.go | 1 + .../provider/ttstream/client_trans_pool.go | 10 +- .../ttstream/client_trans_pool_longconn.go | 6 +- .../ttstream/client_trans_pool_mux.go | 6 +- .../ttstream/client_trans_pool_shortconn.go | 5 +- .../ttstream/container/object_pool.go | 33 +- .../provider/ttstream/container/pipe.go | 45 +- .../provider/ttstream/container/pipe_test.go | 13 +- .../provider/ttstream/container/stack.go | 2 +- .../provider/ttstream/container/stack_test.go | 20 +- pkg/streamx/provider/ttstream/exception.go | 16 + pkg/streamx/provider/ttstream/frame_test.go | 1 + .../provider/ttstream/meta_frame_handler.go | 4 +- .../provider/ttstream/server_provider.go | 41 +- pkg/streamx/provider/ttstream/stream.go | 30 +- .../ttstream/stream_header_trailer.go | 6 +- pkg/streamx/provider/ttstream/transport.go | 4 +- .../provider/ttstream/transport_buffer.go | 14 +- .../provider/ttstream/ttstream_client_test.go | 630 ------------------ .../provider/ttstream/ttstream_common_test.go | 61 -- .../provider/ttstream/ttstream_error_test.go | 174 ----- pkg/streamx/stream.go | 16 +- pkg/streamx/stream_middleware.go | 18 +- pkg/streamx/stream_middleware_internal.go | 16 + pkg/streamx/streamx_common_test.go | 95 +++ ...odec_test.go => streamx_gen_codec_test.go} | 6 +- ...ce_test.go => streamx_gen_service_test.go} | 61 +- ...r_test.go => streamx_user_service_test.go} | 28 +- pkg/streamx/streamx_user_test.go | 596 +++++++++++++++++ server/streamxserver/server.go | 16 + server/streamxserver/server_gen.go | 19 +- server/streamxserver/server_option.go | 22 +- tool/internal_pkg/tpl/streamx/client.go | 16 + .../tpl/streamx/handler.method.go | 16 + tool/internal_pkg/tpl/streamx/server.go | 16 + tool/internal_pkg/tpl/streamx/service.go | 16 + 55 files changed, 1187 insertions(+), 1535 deletions(-) delete mode 100644 pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go delete mode 100644 pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go delete mode 100644 pkg/streamx/provider/jsonrpc/jsonrpc_test.go delete mode 100644 pkg/streamx/provider/ttstream/ttstream_client_test.go delete mode 100644 pkg/streamx/provider/ttstream/ttstream_common_test.go delete mode 100644 pkg/streamx/provider/ttstream/ttstream_error_test.go create mode 100644 pkg/streamx/streamx_common_test.go rename pkg/streamx/{provider/ttstream/ttstream_gen_codec_test.go => streamx_gen_codec_test.go} (99%) rename pkg/streamx/{provider/ttstream/ttstream_gen_service_test.go => streamx_gen_service_test.go} (84%) rename pkg/streamx/{provider/ttstream/ttstream_server_test.go => streamx_user_service_test.go} (88%) create mode 100644 pkg/streamx/streamx_user_test.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 90b0c4a907..e71e772198 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,7 @@ jobs: go-version: ${{ matrix.go }} cache: false # don't use cache for self-hosted runners - name: Unit Test - run: go test -race ./... + run: go test -v -race ./... codegen-test: runs-on: ubuntu-latest diff --git a/client/client.go b/client/client.go index 03cc7cad49..adef12ac6b 100644 --- a/client/client.go +++ b/client/client.go @@ -74,7 +74,6 @@ type kClient struct { sEps endpoint.Endpoint // streamx - sxEps endpoint.Endpoint sxStreamMW streamx.StreamMiddleware sxStreamRecvMW streamx.StreamRecvMiddleware sxStreamSendMW streamx.StreamSendMiddleware diff --git a/client/client_streamx.go b/client/client_streamx.go index 60d1e767bf..6a0fcb769c 100644 --- a/client/client_streamx.go +++ b/client/client_streamx.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package client import ( diff --git a/client/stream_test.go b/client/stream_test.go index 9c07647d2a..0dfee45a4c 100644 --- a/client/stream_test.go +++ b/client/stream_test.go @@ -104,7 +104,7 @@ func TestStreaming(t *testing.T) { cliInfo.ConnPool = connpool s, cr, _ := remotecli.NewStream(ctx, mockRPCInfo, new(mocks.MockCliTransHandler), cliInfo) stream := newStream( - s, cr, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, + s.(streaming.Stream), cr, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, func(stream streaming.Stream, message interface{}) (err error) { return stream.SendMsg(message) }, diff --git a/client/streamxclient/client.go b/client/streamxclient/client.go index 5bd52d8888..09f52707b3 100644 --- a/client/streamxclient/client.go +++ b/client/streamxclient/client.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamxclient import ( diff --git a/client/streamxclient/client_gen.go b/client/streamxclient/client_gen.go index eb1e68f530..001d03c149 100644 --- a/client/streamxclient/client_gen.go +++ b/client/streamxclient/client_gen.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamxclient import ( diff --git a/client/streamxclient/client_option.go b/client/streamxclient/client_option.go index bcc101dc63..0534b305a3 100644 --- a/client/streamxclient/client_option.go +++ b/client/streamxclient/client_option.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamxclient import ( diff --git a/client/streamxclient/streamxcallopt/call_option.go b/client/streamxclient/streamxcallopt/call_option.go index c4acd957fd..de48a90ac9 100644 --- a/client/streamxclient/streamxcallopt/call_option.go +++ b/client/streamxclient/streamxcallopt/call_option.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamxcallopt import ( diff --git a/internal/server/option.go b/internal/server/option.go index ddf0ab82db..5c2a8ffb46 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -23,6 +23,8 @@ import ( "os/signal" "syscall" + "github.com/cloudwego/localsession/backup" + "github.com/cloudwego/kitex/internal/configutil" "github.com/cloudwego/kitex/internal/stream" "github.com/cloudwego/kitex/pkg/acl" @@ -43,7 +45,6 @@ import ( "github.com/cloudwego/kitex/pkg/stats" "github.com/cloudwego/kitex/pkg/transmeta" "github.com/cloudwego/kitex/pkg/utils" - "github.com/cloudwego/localsession/backup" ) func init() { diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index 84132bdcd7..ede3ec40bb 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -20,7 +20,6 @@ import ( "context" "errors" "io" - "log" "net" "runtime/debug" "time" @@ -62,8 +61,10 @@ func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remo }, nil } -var _ remote.ServerTransHandler = &svrTransHandler{} -var errProtocolNotMatch = errors.New("protocol not match") +var ( + _ remote.ServerTransHandler = &svrTransHandler{} + errProtocolNotMatch = errors.New("protocol not match") +) type svrTransHandler struct { opt *remote.ServerOption @@ -106,7 +107,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { streamWorkerPool.GoCtx(ctx, func() { err := t.OnStream(nctx, conn, ss) if err != nil && !errors.Is(err, io.EOF) { - klog.CtxErrorf(ctx, "KITEX: stream ReadStream failed: err=%v", nerr) + klog.CtxErrorf(ctx, "KITEX: stream ReadStream failed: err=%v", err) } }) } @@ -139,13 +140,9 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss stream if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil { _ = mutableTo.SetMethod(ss.Method()) } - //_ = rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(transport.JSONRPC) ctx = t.startTracer(ctx, ri) defer func() { - if err != nil { - log.Println("OnStream failed: ", err) - } panicErr := recover() if panicErr != nil { if conn != nil { diff --git a/pkg/rpcinfo/mocks_test.go b/pkg/rpcinfo/mocks_test.go index 781b6b360f..c0033bbd7e 100644 --- a/pkg/rpcinfo/mocks_test.go +++ b/pkg/rpcinfo/mocks_test.go @@ -90,6 +90,10 @@ func (m *MockRPCConfig) TransportProtocol() (r transport.Protocol) { return } +func (m *MockRPCConfig) StreamRecvTimeout() time.Duration { + return time.Duration(0) +} + type MockRPCStats struct{} func (m *MockRPCStats) Record(context.Context, stats.Event, stats.Status, string) {} diff --git a/pkg/streamx/client_options.go b/pkg/streamx/client_options.go index 9beb142357..935ad0d61b 100644 --- a/pkg/streamx/client_options.go +++ b/pkg/streamx/client_options.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx import ( diff --git a/pkg/streamx/header_trailer.go b/pkg/streamx/header_trailer.go index 449bf99847..50c2bbeaef 100644 --- a/pkg/streamx/header_trailer.go +++ b/pkg/streamx/header_trailer.go @@ -1,4 +1,22 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx -type Header map[string]string -type Trailer map[string]string +type ( + Header map[string]string + Trailer map[string]string +) diff --git a/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go b/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go deleted file mode 100644 index afae42d52a..0000000000 --- a/pkg/streamx/provider/jsonrpc/jsonrpc_gen_test.go +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package jsonrpc_test - -import ( - "context" - - "github.com/cloudwego/kitex/client/streamxclient" - "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" - "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/kitex/pkg/streamx/provider/jsonrpc" - "github.com/cloudwego/kitex/server/streamxserver" -) - -// === gen code === - -type ClientStreamingServer[Req, Res any] streamx.ClientStreamingServer[Req, Res] -type ServerStreamingServer[Res any] streamx.ServerStreamingServer[Res] -type BidiStreamingServer[Req, Res any] streamx.BidiStreamingServer[Req, Res] -type ClientStreamingClient[Req, Res any] streamx.ClientStreamingClient[Req, Res] -type ServerStreamingClient[Res any] streamx.ServerStreamingClient[Res] -type BidiStreamingClient[Req, Res any] streamx.BidiStreamingClient[Req, Res] - -var serviceInfo = &serviceinfo.ServiceInfo{ - ServiceName: "a.b.c", - Methods: map[string]serviceinfo.MethodInfo{ - "Unary": serviceinfo.NewMethodInfo( - func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingUnary, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) - }, - nil, - nil, - false, - serviceinfo.WithStreamingMode(serviceinfo.StreamingUnary), - ), - "ClientStream": serviceinfo.NewMethodInfo( - func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingClient, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) - }, - nil, - nil, - false, - serviceinfo.WithStreamingMode(serviceinfo.StreamingClient), - ), - "ServerStream": serviceinfo.NewMethodInfo( - func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingServer, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) - }, - nil, - nil, - false, - serviceinfo.WithStreamingMode(serviceinfo.StreamingServer), - ), - "BidiStream": serviceinfo.NewMethodInfo( - func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingBidirectional, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) - }, - nil, - nil, - false, - serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), - ), - }, - Extra: map[string]interface{}{"streaming": true, "streamx": true}, -} - -func NewClient(destService string, opts ...streamxclient.Option) (ClientInterface, error) { - var options []streamxclient.Option - options = append(options, streamxclient.WithDestService(destService)) - options = append(options, opts...) - cp, err := jsonrpc.NewClientProvider(serviceInfo) - if err != nil { - return nil, err - } - options = append(options, streamxclient.WithProvider(cp)) - cli, err := streamxclient.NewClient(serviceInfo, options...) - if err != nil { - return nil, err - } - kc := &kClient{Client: cli} - return kc, nil -} - -func NewServer(handler ServerInterface, opts ...streamxserver.Option) (streamxserver.Server, error) { - var options []streamxserver.Option - options = append(options, opts...) - sp, err := jsonrpc.NewServerProvider(serviceInfo) - if err != nil { - return nil, err - } - svr := streamxserver.NewServer(options...) - if err := svr.RegisterService(serviceInfo, handler, streamxserver.WithProvider(sp)); err != nil { - return nil, err - } - return svr, nil -} - -type Request struct { - Type int32 `json:"Type"` - Message string `json:"Message"` -} - -type Response struct { - Type int32 `json:"Type"` - Message string `json:"Message"` -} - -type ServerInterface interface { - Unary(ctx context.Context, req *Request) (*Response, error) - ClientStream(ctx context.Context, stream ClientStreamingServer[Request, Response]) (*Response, error) - ServerStream(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error - BidiStream(ctx context.Context, stream BidiStreamingServer[Request, Response]) error -} - -type ClientInterface interface { - Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error) - ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream ClientStreamingClient[Request, Response], err error) - ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( - stream ServerStreamingClient[Response], err error) - BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream BidiStreamingClient[Request, Response], err error) -} - -// --- Define Client Implementation --- - -var _ ClientInterface = (*kClient)(nil) - -type kClient struct { - streamxclient.Client -} - -func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (*Response, error) { - res := new(Response) - _, err := streamxclient.InvokeStream[Request, Response]( - ctx, c.Client, serviceinfo.StreamingUnary, "Unary", req, res, callOptions...) - if err != nil { - return nil, err - } - return res, nil -} - -func (c *kClient) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream ClientStreamingClient[Request, Response], err error) { - return streamxclient.InvokeStream[Request, Response]( - ctx, c.Client, serviceinfo.StreamingClient, "ClientStream", nil, nil, callOptions...) -} - -func (c *kClient) ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( - stream ServerStreamingClient[Response], err error) { - return streamxclient.InvokeStream[Request, Response]( - ctx, c.Client, serviceinfo.StreamingServer, "ServerStream", req, nil, callOptions...) -} - -func (c *kClient) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream BidiStreamingClient[Request, Response], err error) { - return streamxclient.InvokeStream[Request, Response]( - ctx, c.Client, serviceinfo.StreamingBidirectional, "BidiStream", nil, nil, callOptions...) -} diff --git a/pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go b/pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go deleted file mode 100644 index c3ba1b7644..0000000000 --- a/pkg/streamx/provider/jsonrpc/jsonrpc_impl_test.go +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package jsonrpc_test - -import ( - "context" - "io" - "log" -) - -type serviceImpl struct{} - -func (si *serviceImpl) Unary(ctx context.Context, req *Request) (*Response, error) { - resp := &Response{Message: req.Message} - log.Printf("Server Unary: req={%v} resp={%v}", req, resp) - return resp, nil -} - -func (si *serviceImpl) ClientStream(ctx context.Context, stream ClientStreamingServer[Request, Response]) (res *Response, err error) { - var msg string - defer log.Printf("Server ClientStream end") - for { - req, err := stream.Recv(ctx) - if err == io.EOF { - res = new(Response) - res.Message = msg - return res, nil - } - if err != nil { - return nil, err - } - msg = req.Message - log.Printf("Server ClientStream: req={%v}", req) - } -} - -func (si *serviceImpl) ServerStream(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error { - log.Printf("Server ServerStream: req={%v}", req) - for i := 0; i < 3; i++ { - resp := new(Response) - resp.Type = int32(i) - resp.Message = req.Message - err := stream.Send(ctx, resp) - if err != nil { - return err - } - log.Printf("Server ServerStream: resp={%v}", resp) - } - return nil -} - -func (si *serviceImpl) BidiStream(ctx context.Context, stream BidiStreamingServer[Request, Response]) error { - for { - req, err := stream.Recv(ctx) - if err == io.EOF { - return nil - } - if err != nil { - return err - } - - resp := new(Response) - resp.Message = req.Message - err = stream.Send(ctx, resp) - if err != nil { - return err - } - log.Printf("Server BidiStream: req={%v} resp={%v}", req, resp) - } -} diff --git a/pkg/streamx/provider/jsonrpc/jsonrpc_test.go b/pkg/streamx/provider/jsonrpc/jsonrpc_test.go deleted file mode 100644 index ecfa8c3f6e..0000000000 --- a/pkg/streamx/provider/jsonrpc/jsonrpc_test.go +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package jsonrpc_test - -import ( - "context" - "errors" - "io" - "log" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/cloudwego/kitex/client/streamxclient" - "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/kitex/pkg/streamx/provider/jsonrpc" - "github.com/cloudwego/kitex/server/streamxserver" - "github.com/cloudwego/netpoll" -) - -func TestJSONRPC(t *testing.T) { - var addr = test.GetLocalAddress() - ln, err := netpoll.CreateListener("tcp", addr) - test.Assert(t, err == nil, err) - - // create server - var serverStreamCount int32 - waitServerStreamDone := func() { - for atomic.LoadInt32(&serverStreamCount) != 0 { - t.Logf("waitServerStreamDone: %d", atomic.LoadInt32(&serverStreamCount)) - time.Sleep(time.Millisecond * 100) - } - } - methodCount := map[string]int{} - sp, err := jsonrpc.NewServerProvider(serviceInfo) - test.Assert(t, err == nil, err) - svr := streamxserver.NewServer(streamxserver.WithListener(ln)) - err = svr.RegisterService(serviceInfo, new(serviceImpl), - streamxserver.WithProvider(sp), - streamxserver.WithStreamMiddleware( - // middleware example: server streaming mode - func(next streamx.StreamEndpoint) streamx.StreamEndpoint { - return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { - log.Printf("Server middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", - reqArgs.Req(), resArgs.Res(), streamArgs) - test.Assert(t, streamArgs.Stream() != nil) - - switch streamArgs.Stream().Mode() { - case streamx.StreamingUnary: - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() == nil) - err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() != nil) - case streamx.StreamingClient: - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() == nil) - err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() != nil) - case streamx.StreamingServer: - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() == nil) - err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() == nil) - case streamx.StreamingBidirectional: - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() == nil) - err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() == nil) - } - test.Assert(t, err == nil, err) - methodCount[streamArgs.Stream().Method()]++ - - log.Printf("Server middleware after next: reqArgs=%v resArgs=%v streamArgs=%v", - reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) - atomic.AddInt32(&serverStreamCount, 1) - return nil - } - }, - ), - ) - test.Assert(t, err == nil, err) - go func() { - err := svr.Run() - test.Assert(t, err == nil, err) - }() - defer svr.Stop() - time.Sleep(time.Millisecond * 100) - - // create client - ctx := context.Background() - cli, err := NewClient( - "a.b.c", - streamxclient.WithHostPorts(addr), - streamxclient.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { - return func(ctx context.Context, stream streamx.Stream, res any) (err error) { - return next(ctx, stream, res) - } - }), - streamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { - return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { - log.Printf("Client middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", - reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) - err = next(ctx, streamArgs, reqArgs, resArgs) - log.Printf("Client middleware after next: reqArgs=%v resArgs=%v streamArgs=%v", - reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) - - test.Assert(t, streamArgs.Stream() != nil) - switch streamArgs.Stream().Mode() { - case streamx.StreamingUnary: - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() != nil) - case streamx.StreamingClient: - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() == nil) - case streamx.StreamingServer: - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() == nil) - case streamx.StreamingBidirectional: - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() == nil) - } - test.Assert(t, err == nil, err) - return err - } - }), - ) - test.Assert(t, err == nil, err) - - t.Logf("=== Unary ===") - req := new(Request) - req.Message = "Unary" - res, err := cli.Unary(ctx, req) - test.Assert(t, err == nil, err) - test.Assert(t, req.Message == res.Message, res.Message) - atomic.AddInt32(&serverStreamCount, -1) - waitServerStreamDone() - - // client stream - t.Logf("=== ClientStream ===") - cs, err := cli.ClientStream(ctx) - test.Assert(t, err == nil, err) - for i := 0; i < 3; i++ { - req := new(Request) - req.Type = int32(i) - req.Message = "ClientStream" - err = cs.Send(ctx, req) - test.Assert(t, err == nil, err) - } - res, err = cs.CloseAndRecv(ctx) - test.Assert(t, err == nil, err) - test.Assert(t, res.Message == "ClientStream", res.Message) - t.Logf("Client ClientStream CloseAndRecv: %v", res) - atomic.AddInt32(&serverStreamCount, -1) - waitServerStreamDone() - - // server stream - t.Logf("=== ServerStream ===") - req = new(Request) - req.Message = "ServerStream" - ss, err := cli.ServerStream(ctx, req) - test.Assert(t, err == nil, err) - for { - res, err := ss.Recv(ctx) - if errors.Is(err, io.EOF) { - break - } - test.Assert(t, err == nil, err) - t.Logf("Client ServerStream recv: %v", res) - } - //err = ss.CloseSend(ctx) - //test.Assert(t, err == nil, err) - atomic.AddInt32(&serverStreamCount, -1) - waitServerStreamDone() - - // bidi stream - t.Logf("=== BidiStream ===") - bs, err := cli.BidiStream(ctx) - test.Assert(t, err == nil, err) - round := 5 - msg := "BidiStream" - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - for i := 0; i < round; i++ { - req := new(Request) - req.Message = msg - err := bs.Send(ctx, req) - test.Assert(t, err == nil, err) - } - err = bs.CloseSend(ctx) - test.Assert(t, err == nil, err) - }() - go func() { - defer wg.Done() - i := 0 - for { - res, err := bs.Recv(ctx) - if errors.Is(err, io.EOF) { - break - } - i++ - test.Assert(t, err == nil, err) - test.Assert(t, msg == res.Message, res.Message) - } - test.Assert(t, i == round, i) - }() - wg.Wait() - atomic.AddInt32(&serverStreamCount, -1) - waitServerStreamDone() -} diff --git a/pkg/streamx/provider/jsonrpc/protocol.go b/pkg/streamx/provider/jsonrpc/protocol.go index 3d8654ea29..cbac87b90d 100644 --- a/pkg/streamx/provider/jsonrpc/protocol.go +++ b/pkg/streamx/provider/jsonrpc/protocol.go @@ -78,7 +78,7 @@ type Frame struct { payload []byte } -func newFrame(typ int, sid int, service, method string, payload []byte) Frame { +func newFrame(typ, sid int, service, method string, payload []byte) Frame { return Frame{ typ: typ, sid: sid, @@ -115,6 +115,7 @@ func EncodeFrame(writer io.Writer, frame Frame) (err error) { offset += len(frame.method) copy(data[offset:offset+len(frame.payload)], frame.payload) offset += len(frame.payload) + _ = offset idx := 0 for idx < len(data) { diff --git a/pkg/streamx/provider/jsonrpc/transport.go b/pkg/streamx/provider/jsonrpc/transport.go index e761f89083..399fe3ba9d 100644 --- a/pkg/streamx/provider/jsonrpc/transport.go +++ b/pkg/streamx/provider/jsonrpc/transport.go @@ -24,9 +24,10 @@ import ( "sync" "sync/atomic" + "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/netpoll" ) type transport struct { diff --git a/pkg/streamx/provider/jsonrpc/transport_test.go b/pkg/streamx/provider/jsonrpc/transport_test.go index 0327eb12f6..187e15e7c8 100644 --- a/pkg/streamx/provider/jsonrpc/transport_test.go +++ b/pkg/streamx/provider/jsonrpc/transport_test.go @@ -68,7 +68,7 @@ func TestTransport(t *testing.T) { Extra: map[string]interface{}{"streaming": true}, } - var addr = test.GetLocalAddress() + addr := test.GetLocalAddress() ln, err := net.Listen("tcp", addr) test.Assert(t, err == nil, err) diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go index 5c82721a5b..f79e760151 100644 --- a/pkg/streamx/provider/ttstream/client_provier.go +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -23,6 +23,7 @@ import ( "github.com/bytedance/gopkg/cloud/metainfo" "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/rpcinfo" diff --git a/pkg/streamx/provider/ttstream/client_trans_pool.go b/pkg/streamx/provider/ttstream/client_trans_pool.go index f694b9db3c..7ec587d93d 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool.go @@ -16,9 +16,15 @@ package ttstream -import "github.com/cloudwego/kitex/pkg/serviceinfo" +import ( + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +var dialer = netpoll.NewDialer() type transPool interface { - Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (trans *transport, err error) + Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (trans *transport, err error) Put(trans *transport) } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go index 3b3dff2750..bae4f6ec6e 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go @@ -19,8 +19,6 @@ package ttstream import ( "time" - "github.com/cloudwego/netpoll" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" ) @@ -48,7 +46,7 @@ type longConnTransPool struct { config LongConnConfig } -func (c *longConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (trans *transport, err error) { +func (c *longConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (trans *transport, err error) { for { o := c.transPool.Pop(addr) if o == nil { @@ -61,7 +59,7 @@ func (c *longConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, } // create new connection - conn, err := netpoll.DialConnection(network, addr, time.Second) + conn, err := dialer.DialConnection(network, addr, time.Second) if err != nil { return nil, err } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go index dfdf534dcf..d0d9fc8a60 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go @@ -44,7 +44,7 @@ func newMuxTransList(size int) *muxTransList { return tl } -func (tl *muxTransList) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (*transport, error) { +func (tl *muxTransList) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (*transport, error) { idx := atomic.AddUint32(&tl.cursor, 1) % uint32(tl.size) tl.L.RLock() trans := tl.transports[idx] @@ -53,7 +53,7 @@ func (tl *muxTransList) Get(sinfo *serviceinfo.ServiceInfo, network string, addr return trans, nil } - conn, err := netpoll.DialConnection(network, addr, time.Second) + conn, err := dialer.DialConnection(network, addr, time.Second) if err != nil { return nil, err } @@ -85,7 +85,7 @@ type muxTransPool struct { sflight singleflight.Group } -func (m *muxTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (trans *transport, err error) { +func (m *muxTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (trans *transport, err error) { v, ok := m.pool.Load(addr) if ok { return v.(*muxTransList).Get(sinfo, network, addr) diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go index 37c803b3c7..e5eb081520 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go @@ -20,7 +20,6 @@ import ( "time" "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/netpoll" ) func newShortConnTransPool() transPool { @@ -29,9 +28,9 @@ func newShortConnTransPool() transPool { type shortConnTransPool struct{} -func (c *shortConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (*transport, error) { +func (c *shortConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (*transport, error) { // create new connection - conn, err := netpoll.DialConnection(network, addr, time.Second) + conn, err := dialer.DialConnection(network, addr, time.Second) if err != nil { return nil, err } diff --git a/pkg/streamx/provider/ttstream/container/object_pool.go b/pkg/streamx/provider/ttstream/container/object_pool.go index 26fdc26978..51658e706e 100644 --- a/pkg/streamx/provider/ttstream/container/object_pool.go +++ b/pkg/streamx/provider/ttstream/container/object_pool.go @@ -1,11 +1,25 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package container import ( "sync" "sync/atomic" "time" - - "github.com/cloudwego/kitex/pkg/klog" ) type Object interface { @@ -78,14 +92,9 @@ func (s *ObjectPool) cleaning() { cleanInternal = time.Second * 10 } // clean objects - for key, stk := range s.objects { + for _, stk := range s.objects { deleted := 0 - var oldest *time.Time - klog.Infof("object[%s] pool cleaning %d objects", key, stk.Size()) - stk.RangeDelete(func(o objectItem) (deleteNode bool, continueRange bool) { - if oldest == nil { - oldest = &o.lastActive - } + stk.RangeDelete(func(o objectItem) (deleteNode, continueRange bool) { if o.object == nil { deleted++ return true, true @@ -97,13 +106,9 @@ func (s *ObjectPool) cleaning() { return false, false } deleted++ - err := o.object.Close() - klog.Infof("object is invalid: lastActive=%s, closedErr=%v", o.lastActive.String(), err) + _ = o.object.Close() return true, true }) - if oldest != nil { - klog.Infof("object[%s] pool deleted %d objects, oldest=%s", key, deleted, oldest.String()) - } } s.L.Unlock() } diff --git a/pkg/streamx/provider/ttstream/container/pipe.go b/pkg/streamx/provider/ttstream/container/pipe.go index 6847002003..a2c8254354 100644 --- a/pkg/streamx/provider/ttstream/container/pipe.go +++ b/pkg/streamx/provider/ttstream/container/pipe.go @@ -32,12 +32,14 @@ const ( pipeStateCanceled pipeState = 3 ) -var ErrPipeEOF = io.EOF -var ErrPipeCanceled = fmt.Errorf("pipe canceled") -var stateErrors map[pipeState]error = map[pipeState]error{ - pipeStateClosed: ErrPipeEOF, - pipeStateCanceled: ErrPipeCanceled, -} +var ( + ErrPipeEOF = io.EOF + ErrPipeCanceled = fmt.Errorf("pipe canceled") + stateErrors = map[pipeState]error{ + pipeStateClosed: ErrPipeEOF, + pipeStateCanceled: ErrPipeCanceled, + } +) // Pipe implement a queue that never block on Write but block on Read if there is nothing to read type Pipe[Item any] struct { @@ -49,14 +51,13 @@ type Pipe[Item any] struct { func NewPipe[Item any]() *Pipe[Item] { p := new(Pipe[Item]) p.queue = NewQueue[Item]() - p.trigger = make(chan struct{}) + p.trigger = make(chan struct{}, 1) return p } // Read will block if there is nothing to read -func (p *Pipe[Item]) Read(ctx context.Context, items []Item) (int, error) { +func (p *Pipe[Item]) Read(ctx context.Context, items []Item) (n int, err error) { READ: - var n int for i := 0; i < len(items); i++ { val, ok := p.queue.Get() if !ok { @@ -82,22 +83,22 @@ READ: } if p.queue.Size() == 0 { - err := stateErrors[atomic.LoadInt32(&p.state)] + err = stateErrors[atomic.LoadInt32(&p.state)] if err != nil { return 0, err } - return 0, fmt.Errorf("unknown err") } goto READ } } -func (p *Pipe[Item]) Write(ctx context.Context, items ...Item) error { +func (p *Pipe[Item]) Write(ctx context.Context, items ...Item) (err error) { if !atomic.CompareAndSwapInt32(&p.state, pipeStateInactive, pipeStateActive) && atomic.LoadInt32(&p.state) != pipeStateActive { - err := stateErrors[atomic.LoadInt32(&p.state)] + err = stateErrors[atomic.LoadInt32(&p.state)] if err != nil { return err } + // never happen error return fmt.Errorf("unknown state error") } @@ -113,19 +114,21 @@ func (p *Pipe[Item]) Write(ctx context.Context, items ...Item) error { } func (p *Pipe[Item]) Close() { - select { - case <-p.trigger: - default: + if atomic.LoadInt32(&p.state) != pipeStateClosed { atomic.StoreInt32(&p.state, pipeStateClosed) - close(p.trigger) + select { + case p.trigger <- struct{}{}: + default: + } } } func (p *Pipe[Item]) Cancel() { - select { - case <-p.trigger: - default: + if atomic.LoadInt32(&p.state) != pipeStateCanceled { atomic.StoreInt32(&p.state, pipeStateCanceled) - close(p.trigger) + select { + case p.trigger <- struct{}{}: + default: + } } } diff --git a/pkg/streamx/provider/ttstream/container/pipe_test.go b/pkg/streamx/provider/ttstream/container/pipe_test.go index 761589d30b..7475fd13d8 100644 --- a/pkg/streamx/provider/ttstream/container/pipe_test.go +++ b/pkg/streamx/provider/ttstream/container/pipe_test.go @@ -18,6 +18,7 @@ package container import ( "context" + "io" "sync" "testing" ) @@ -34,6 +35,9 @@ func TestPipeline(t *testing.T) { for { n, err := pipe.Read(ctx, items) if err != nil { + if err != io.EOF { + t.Error(err) + } return } for i := 0; i < n; i++ { @@ -44,12 +48,17 @@ func TestPipeline(t *testing.T) { round := 10000 itemsPerRound := []int{1, 1, 1, 1, 1} for i := 0; i < round; i++ { - _ = pipe.Write(ctx, itemsPerRound...) + err := pipe.Write(ctx, itemsPerRound...) + if err != nil { + t.Fatal(err) + } } + t.Logf("Pipe closing") pipe.Close() + t.Logf("Pipe closed") wg.Wait() if recv != len(itemsPerRound)*round { - t.Fatalf("expect %d items, got %d", len(itemsPerRound)*round, recv) + t.Fatalf("Pipe expect %d items, got %d", len(itemsPerRound)*round, recv) } } diff --git a/pkg/streamx/provider/ttstream/container/stack.go b/pkg/streamx/provider/ttstream/container/stack.go index 7083fc27e0..e9a0a883c1 100644 --- a/pkg/streamx/provider/ttstream/container/stack.go +++ b/pkg/streamx/provider/ttstream/container/stack.go @@ -40,7 +40,7 @@ func (s *Stack[ValueType]) Size() (size int) { } // RangeDelete range from the stack bottom -func (s *Stack[ValueType]) RangeDelete(checking func(v ValueType) (deleteNode bool, continueRange bool)) { +func (s *Stack[ValueType]) RangeDelete(checking func(v ValueType) (deleteNode, continueRange bool)) { // Stop the world! s.L.Lock() // range from the stack bottom(oldest item) diff --git a/pkg/streamx/provider/ttstream/container/stack_test.go b/pkg/streamx/provider/ttstream/container/stack_test.go index 420bd636b5..fe08d64ce9 100644 --- a/pkg/streamx/provider/ttstream/container/stack_test.go +++ b/pkg/streamx/provider/ttstream/container/stack_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package container import ( @@ -75,12 +91,12 @@ func TestStackRangeDelete(t *testing.T) { for i := 1; i <= round; i++ { stk.Push(i) } - stk.RangeDelete(func(v int) (deleteNode bool, continueRange bool) { + stk.RangeDelete(func(v int) (deleteNode, continueRange bool) { return v%2 == 0, true }) test.Assert(t, stk.Size() == round/2, stk.Size()) size := 0 - stk.RangeDelete(func(v int) (deleteNode bool, continueRange bool) { + stk.RangeDelete(func(v int) (deleteNode, continueRange bool) { size++ return false, true }) diff --git a/pkg/streamx/provider/ttstream/exception.go b/pkg/streamx/provider/ttstream/exception.go index 5c650784b2..17f51b157d 100644 --- a/pkg/streamx/provider/ttstream/exception.go +++ b/pkg/streamx/provider/ttstream/exception.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ttstream type tException interface { diff --git a/pkg/streamx/provider/ttstream/frame_test.go b/pkg/streamx/provider/ttstream/frame_test.go index 19283d18c5..6bba79a251 100644 --- a/pkg/streamx/provider/ttstream/frame_test.go +++ b/pkg/streamx/provider/ttstream/frame_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/kitex/internal/test" ) diff --git a/pkg/streamx/provider/ttstream/meta_frame_handler.go b/pkg/streamx/provider/ttstream/meta_frame_handler.go index 35c3f5f67c..e5dd6a54dc 100644 --- a/pkg/streamx/provider/ttstream/meta_frame_handler.go +++ b/pkg/streamx/provider/ttstream/meta_frame_handler.go @@ -25,7 +25,7 @@ import ( type StreamMeta interface { Meta() map[string]string GetMeta(k string) (string, bool) - SetMeta(k string, v string, kvs ...string) + SetMeta(k, v string, kvs ...string) } type MetaFrameHandler interface { @@ -60,7 +60,7 @@ func (s *streamMeta) GetMeta(k string) (string, bool) { return v, ok } -func (s *streamMeta) SetMeta(k string, v string, kvs ...string) { +func (s *streamMeta) SetMeta(k, v string, kvs ...string) { s.sync.RLock() s.data[k] = v for i := 0; i < len(kvs); i += 2 { diff --git a/pkg/streamx/provider/ttstream/server_provider.go b/pkg/streamx/provider/ttstream/server_provider.go index 7c9add2f80..5b62dea78f 100644 --- a/pkg/streamx/provider/ttstream/server_provider.go +++ b/pkg/streamx/provider/ttstream/server_provider.go @@ -34,8 +34,10 @@ import ( "github.com/cloudwego/kitex/pkg/utils" ) -type serverTransCtxKey struct{} -type serverStreamCancelCtxKey struct{} +type ( + serverTransCtxKey struct{} + serverStreamCancelCtxKey struct{} +) func NewServerProvider(sinfo *serviceinfo.ServiceInfo, opts ...ServerProviderOption) (streamx.ServerProvider, error) { sp := new(serverProvider) @@ -96,24 +98,33 @@ func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStr sst := ss.(*serverStream) var exception tException if err != nil { - switch err.(type) { - case tException: - exception = err.(tException) + switch terr := err.(type) { case kerrors.BizStatusErrorIface: - bizErr := err.(kerrors.BizStatusErrorIface) - sst.appendTrailer( - "biz-status", strconv.Itoa(int(bizErr.BizStatusCode())), - "biz-message", bizErr.BizMessage(), - ) - if bizErr.BizExtra() != nil { - extra, _ := utils.Map2JSONStr(bizErr.BizExtra()) - sst.appendTrailer("biz-extra", extra) + bizStatus := strconv.Itoa(int(terr.BizStatusCode())) + bizMsg := terr.BizMessage() + if terr.BizExtra() == nil { + err = sst.writeTrailer(streamx.Trailer{ + "biz-status": bizStatus, + "biz-message": bizMsg, + }) + } else { + bizExtra, _ := utils.Map2JSONStr(terr.BizExtra()) + err = sst.writeTrailer(streamx.Trailer{ + "biz-status": bizStatus, + "biz-message": bizMsg, + "biz-extra": bizExtra, + }) } + if err != nil { + return nil, err + } + case tException: + exception = terr default: - exception = thrift.NewApplicationException(remote.InternalError, err.Error()) + exception = thrift.NewApplicationException(remote.InternalError, terr.Error()) } } - if err := sst.close(exception); err != nil { + if err = sst.close(exception); err != nil { return nil, err } diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index f426ed3bf7..3bdb35dc25 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -129,14 +129,6 @@ func (s *stream) readHeader(hd streamx.Header) (err error) { return nil } -// setHeader use the hd as the underlying header -func (s *stream) setHeader(hd streamx.Header) { - if hd != nil { - s.wheader = hd - } - return -} - // writeHeader copy kvs into s.wheader func (s *stream) writeHeader(hd streamx.Header) error { if s.wheader == nil { @@ -205,21 +197,6 @@ func (s *stream) writeTrailer(tl streamx.Trailer) (err error) { return nil } -func (s *stream) appendTrailer(kvs ...string) (err error) { - if len(kvs)%2 != 0 { - return fmt.Errorf("got the odd number of input kvs for Trailer: %d", len(kvs)) - } - var key string - for i, str := range kvs { - if i%2 == 0 { - key = str - continue - } - s.wtrailer[key] = str - } - return nil -} - func (s *stream) sendTrailer(ctx context.Context, ex tException) (err error) { if !atomic.CompareAndSwapInt32(&s.selfEOF, 0, 1) { return nil @@ -229,15 +206,10 @@ func (s *stream) sendTrailer(ctx context.Context, ex tException) (err error) { if wtrailer == nil { return fmt.Errorf("stream trailer already sent") } - klog.Debugf("transport[%d]-stream[%d] send trialer", s.trans.kind, s.sid) + klog.Debugf("transport[%d]-stream[%d] send trailer", s.trans.kind, s.sid) return s.trans.streamCloseSend(s.sid, s.method, wtrailer, ex) } -func (s *stream) finished() bool { - return atomic.LoadInt32(&s.peerEOF) == 1 && - atomic.LoadInt32(&s.selfEOF) == 1 -} - func (s *stream) setRecvTimeout(timeout time.Duration) { if timeout <= 0 { return diff --git a/pkg/streamx/provider/ttstream/stream_header_trailer.go b/pkg/streamx/provider/ttstream/stream_header_trailer.go index beedb24f41..611c12a0ca 100644 --- a/pkg/streamx/provider/ttstream/stream_header_trailer.go +++ b/pkg/streamx/provider/ttstream/stream_header_trailer.go @@ -22,8 +22,10 @@ import ( "github.com/cloudwego/kitex/pkg/streamx" ) -var _ ClientStreamMeta = (*clientStream)(nil) -var _ ServerStreamMeta = (*serverStream)(nil) +var ( + _ ClientStreamMeta = (*clientStream)(nil) + _ ServerStreamMeta = (*serverStream)(nil) +) func (s *clientStream) Header() (streamx.Header, error) { sig := <-s.headerSig diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index e7686bfdf3..dc97d62f09 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -222,6 +222,7 @@ func (t *transport) loopWrite() error { } for i := 0; i < n; i++ { fr := fcache[i] + klog.Debugf("transport[%d] EncodeFrame: fr=%v", t.kind, fr) if err = EncodeFrame(context.Background(), writer, fr); err != nil { return err } @@ -332,7 +333,8 @@ var clientStreamID int32 // it's typically used by client side // newStreamIO is concurrency safe func (t *transport) newStreamIO( - ctx context.Context, method string, intHeader IntHeader, strHeader streamx.Header) (*streamIO, error) { + ctx context.Context, method string, intHeader IntHeader, strHeader streamx.Header, +) (*streamIO, error) { if t.kind != clientTransport { return nil, fmt.Errorf("transport already be used as other kind") } diff --git a/pkg/streamx/provider/ttstream/transport_buffer.go b/pkg/streamx/provider/ttstream/transport_buffer.go index a51a47c64d..36620fac46 100644 --- a/pkg/streamx/provider/ttstream/transport_buffer.go +++ b/pkg/streamx/provider/ttstream/transport_buffer.go @@ -24,12 +24,16 @@ import ( "github.com/cloudwego/netpoll" ) -var _ bufiox.Reader = (*readerBuffer)(nil) -var _ bufiox.Writer = (*writerBuffer)(nil) -var _ gopkgthrift.NocopyWriter = (*writerBuffer)(nil) +var ( + _ bufiox.Reader = (*readerBuffer)(nil) + _ bufiox.Writer = (*writerBuffer)(nil) + _ gopkgthrift.NocopyWriter = (*writerBuffer)(nil) +) -var readerBufferPool sync.Pool -var writerBufferPool sync.Pool +var ( + readerBufferPool sync.Pool + writerBufferPool sync.Pool +) func newReaderBuffer(reader netpoll.Reader) (rb *readerBuffer) { if v := readerBufferPool.Get(); v != nil { diff --git a/pkg/streamx/provider/ttstream/ttstream_client_test.go b/pkg/streamx/provider/ttstream/ttstream_client_test.go deleted file mode 100644 index 6f918c1d56..0000000000 --- a/pkg/streamx/provider/ttstream/ttstream_client_test.go +++ /dev/null @@ -1,630 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ttstream_test - -import ( - "context" - "errors" - "io" - "log" - "net/http" - _ "net/http/pprof" - "runtime" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/cloudwego/netpoll" - - "github.com/cloudwego/kitex/client" - "github.com/cloudwego/kitex/client/streamxclient" - "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/klog" - "github.com/cloudwego/kitex/pkg/remote/codec/thrift" - "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" - "github.com/cloudwego/kitex/server" - "github.com/cloudwego/kitex/server/streamxserver" - "github.com/cloudwego/kitex/transport" -) - -func init() { - klog.SetLevel(klog.LevelDebug) -} - -func testHeaderAndTrailer(t *testing.T, stream streamx.ClientStreamMetadata) { - hd, err := stream.Header() - test.Assert(t, err == nil, err) - test.Assert(t, hd[headerKey] == headerVal, hd) - tl, err := stream.Trailer() - test.Assert(t, err == nil, err) - test.Assert(t, tl[trailerKey] == trailerVal, tl) -} - -func TestMain(m *testing.M) { - go func() { - log.Println(http.ListenAndServe("localhost:6060", nil)) - }() - m.Run() -} - -func TestTTHeaderStreaming(t *testing.T) { - var addr = test.GetLocalAddress() - ln, err := netpoll.CreateListener("tcp", addr) - test.Assert(t, err == nil, err) - defer ln.Close() - - // create server - var serverStreamCount int32 - waitServerStreamDone := func() { - for atomic.LoadInt32(&serverStreamCount) != 0 { - t.Logf("waitServerStreamDone: %d", atomic.LoadInt32(&serverStreamCount)) - time.Sleep(time.Millisecond * 100) - } - } - var serverRecvCount int32 - var serverSendCount int32 - svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) - // register pingpong service - err = svr.RegisterService(pingpongServiceInfo, new(pingpongService)) - test.Assert(t, err == nil, err) - // register streamingService as ttstreaam provider - sp, err := ttstream.NewServerProvider(streamingServiceInfo) - test.Assert(t, err == nil, err) - err = svr.RegisterService( - streamingServiceInfo, - new(streamingService), - streamxserver.WithProvider(sp), - streamxserver.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { - return func(ctx context.Context, stream streamx.Stream, res any) (err error) { - err = next(ctx, stream, res) - if err == nil { - atomic.AddInt32(&serverRecvCount, 1) - } else { - log.Printf("server recv middleware err=%v", err) - } - return err - } - }), - streamxserver.WithStreamSendMiddleware(func(next streamx.StreamSendEndpoint) streamx.StreamSendEndpoint { - return func(ctx context.Context, stream streamx.Stream, req any) (err error) { - err = next(ctx, stream, req) - if err == nil { - atomic.AddInt32(&serverSendCount, 1) - } else { - log.Printf("server send middleware err=%v", err) - } - return err - } - }), - streamxserver.WithStreamMiddleware( - // middleware example: server streaming mode - func(next streamx.StreamEndpoint) streamx.StreamEndpoint { - return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { - log.Printf("Server middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", - reqArgs.Req(), resArgs.Res(), streamArgs) - test.Assert(t, streamArgs.Stream() != nil) - test.Assert(t, ValidateMetadata(ctx)) - - log.Printf("Server handler start") - switch streamArgs.Stream().Mode() { - case streamx.StreamingUnary: - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() == nil) - err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() != nil) - case streamx.StreamingClient: - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() == nil) - err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() != nil) - case streamx.StreamingServer: - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() == nil) - err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() == nil) - case streamx.StreamingBidirectional: - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() == nil) - err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() == nil) - } - test.Assert(t, err == nil, err) - log.Printf("Server handler end") - - log.Printf("Server middleware after next: reqArgs=%v resArgs=%v streamArgs=%v", - reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) - atomic.AddInt32(&serverStreamCount, 1) - return nil - } - }, - ), - ) - test.Assert(t, err == nil, err) - go func() { - err := svr.Run() - test.Assert(t, err == nil, err) - }() - defer svr.Stop() - test.WaitServerStart(addr) - - // create client - pingpongClient, err := NewPingPongClient( - "kitex.service.pingpong", - client.WithHostPorts(addr), - client.WithTransportProtocol(transport.TTHeaderFramed), - client.WithPayloadCodec(thrift.NewThriftCodecWithConfig(thrift.FastRead|thrift.FastWrite|thrift.EnableSkipDecoder)), - ) - test.Assert(t, err == nil, err) - // create streaming client - cp, _ := ttstream.NewClientProvider( - streamingServiceInfo, - ttstream.WithClientLongConnPool(ttstream.DefaultLongConnConfig), - ) - streamClient, err := NewStreamingClient( - "kitex.service.streaming", - streamxclient.WithProvider(cp), - streamxclient.WithHostPorts(addr), - streamxclient.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { - return func(ctx context.Context, stream streamx.Stream, res any) (err error) { - err = next(ctx, stream, res) - log.Printf("Client recv middleware %v", res) - return err - } - }), - streamxclient.WithStreamSendMiddleware(func(next streamx.StreamSendEndpoint) streamx.StreamSendEndpoint { - return func(ctx context.Context, stream streamx.Stream, req any) (err error) { - err = next(ctx, stream, req) - log.Printf("Client send middleware %v", req) - return err - } - }), - streamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { - return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { - // validate ctx - test.Assert(t, ValidateMetadata(ctx)) - - log.Printf("Client middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", - reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) - err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, err == nil, err) - log.Printf("Client middleware after next: reqArgs=%v resArgs=%v streamArgs=%v", - reqArgs.Req(), resArgs.Res(), streamArgs.Stream()) - - test.Assert(t, streamArgs.Stream() != nil) - switch streamArgs.Stream().Mode() { - case streamx.StreamingUnary: - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() != nil) - case streamx.StreamingClient: - test.Assert(t, reqArgs.Req() == nil, reqArgs.Req()) - test.Assert(t, resArgs.Res() == nil) - case streamx.StreamingServer: - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() == nil) - case streamx.StreamingBidirectional: - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() == nil) - } - return err - } - }), - ) - test.Assert(t, err == nil, err) - - // prepare metainfo - ctx := context.Background() - ctx = SetMetadata(ctx) - - t.Logf("=== PingPong ===") - req := new(Request) - req.Message = "PingPong" - res, err := pingpongClient.PingPong(ctx, req) - test.Assert(t, err == nil, err) - test.Assert(t, req.Message == res.Message, res) - - t.Logf("=== Unary ===") - req = new(Request) - req.Type = 10000 - req.Message = "Unary" - res, err = streamClient.Unary(ctx, req) - test.Assert(t, err == nil, err) - test.Assert(t, req.Type == res.Type, res.Type) - test.Assert(t, req.Message == res.Message, res.Message) - test.Assert(t, serverRecvCount == 1, serverRecvCount) - test.Assert(t, serverSendCount == 1, serverSendCount) - atomic.AddInt32(&serverStreamCount, -1) - waitServerStreamDone() - serverRecvCount = 0 - serverSendCount = 0 - - // client stream - round := 5 - t.Logf("=== ClientStream ===") - cs, err := streamClient.ClientStream(ctx) - test.Assert(t, err == nil, err) - for i := 0; i < round; i++ { - req := new(Request) - req.Type = int32(i) - req.Message = "ClientStream" - err = cs.Send(ctx, req) - test.Assert(t, err == nil, err) - } - res, err = cs.CloseAndRecv(ctx) - test.Assert(t, err == nil, err) - test.Assert(t, res.Message == "ClientStream", res.Message) - t.Logf("Client ClientStream CloseAndRecv: %v", res) - atomic.AddInt32(&serverStreamCount, -1) - waitServerStreamDone() - test.DeepEqual(t, serverRecvCount, int32(round)) - test.Assert(t, serverSendCount == 1, serverSendCount) - testHeaderAndTrailer(t, cs) - cs = nil - serverRecvCount = 0 - serverSendCount = 0 - runtime.GC() - - // server stream - t.Logf("=== ServerStream ===") - req = new(Request) - req.Message = "ServerStream" - ss, err := streamClient.ServerStream(ctx, req) - test.Assert(t, err == nil, err) - received := 0 - for { - res, err := ss.Recv(ctx) - if errors.Is(err, io.EOF) { - break - } - test.Assert(t, err == nil, err) - received++ - t.Logf("Client ServerStream recv: %v", res) - } - err = ss.CloseSend(ctx) - test.Assert(t, err == nil, err) - atomic.AddInt32(&serverStreamCount, -1) - waitServerStreamDone() - test.Assert(t, serverRecvCount == 1, serverRecvCount) - test.Assert(t, serverSendCount == int32(received), serverSendCount, received) - testHeaderAndTrailer(t, ss) - ss = nil - serverRecvCount = 0 - serverSendCount = 0 - runtime.GC() - - // bidi stream - t.Logf("=== BidiStream ===") - concurrent := 1 - round = 5 - for c := 0; c < concurrent; c++ { - atomic.AddInt32(&serverStreamCount, -1) - go func() { - bs, err := streamClient.BidiStream(ctx) - test.Assert(t, err == nil, err) - msg := "BidiStream" - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - for i := 0; i < round; i++ { - req := new(Request) - req.Message = msg - err := bs.Send(ctx, req) - test.Assert(t, err == nil, err) - } - err = bs.CloseSend(ctx) - test.Assert(t, err == nil, err) - }() - go func() { - defer wg.Done() - i := 0 - for { - res, err := bs.Recv(ctx) - t.Log(res, err) - if errors.Is(err, io.EOF) { - break - } - i++ - test.Assert(t, err == nil, err) - test.Assert(t, msg == res.Message, res.Message) - } - test.Assert(t, i == round, i) - }() - testHeaderAndTrailer(t, bs) - }() - } - waitServerStreamDone() - test.Assert(t, serverRecvCount == int32(concurrent*round), serverRecvCount) - test.Assert(t, serverSendCount == int32(concurrent*round), serverSendCount) - serverRecvCount = 0 - serverSendCount = 0 - runtime.GC() - - streamClient = nil -} - -func TestTTHeaderStreamingLongConn(t *testing.T) { - go func() { - log.Println(http.ListenAndServe("localhost:6060", nil)) - }() - - var addr = test.GetLocalAddress() - ln, _ := netpoll.CreateListener("tcp", addr) - defer ln.Close() - - // create server - svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) - // register streamingService as ttstreaam provider - sp, _ := ttstream.NewServerProvider(streamingServiceInfo) - _ = svr.RegisterService( - streamingServiceInfo, - new(streamingService), - streamxserver.WithProvider(sp), - ) - go func() { - _ = svr.Run() - }() - defer svr.Stop() - test.WaitServerStart(addr) - - numGoroutine := runtime.NumGoroutine() - cp, _ := ttstream.NewClientProvider( - streamingServiceInfo, - ttstream.WithClientLongConnPool( - ttstream.LongConnConfig{MaxIdleTimeout: time.Second}, - ), - ) - streamClient, _ := NewStreamingClient( - "kitex.service.streaming", - streamxclient.WithHostPorts(addr), - streamxclient.WithProvider(cp), - ) - ctx := context.Background() - msg := "BidiStream" - - t.Logf("checking only one connection be reused") - var wg sync.WaitGroup - for i := 0; i < 12; i++ { - wg.Add(1) - bs, err := streamClient.BidiStream(ctx) - test.Assert(t, err == nil, err) - req := new(Request) - req.Message = string(make([]byte, 1024)) - err = bs.Send(ctx, req) - test.Assert(t, err == nil, err) - res, err := bs.Recv(ctx) - test.Assert(t, err == nil, err) - err = bs.CloseSend(ctx) - test.Assert(t, err == nil, err) - test.Assert(t, res.Message == req.Message, res.Message) - runtime.SetFinalizer(bs, func(_ any) { - wg.Done() - t.Logf("stream is finalized") - }) - bs = nil - runtime.GC() - wg.Wait() - } - - t.Logf("checking goroutines destroy") - // checking streaming goroutines - streams := 500 - for i := 0; i < streams; i++ { - wg.Add(1) - go func() { - bs, err := streamClient.BidiStream(ctx) - test.Assert(t, err == nil, err) - req := new(Request) - req.Message = msg - err = bs.Send(ctx, req) - test.Assert(t, err == nil, err) - go func() { - defer wg.Done() - res, err := bs.Recv(ctx) - test.Assert(t, err == nil, err) - err = bs.CloseSend(ctx) - test.Assert(t, err == nil, err) - test.Assert(t, res.Message == msg, res.Message) - - testHeaderAndTrailer(t, bs) - }() - }() - } - wg.Wait() - for { - ng := runtime.NumGoroutine() - if ng-numGoroutine < 10 { - break - } - runtime.GC() - time.Sleep(time.Second) - t.Logf("current goroutines=%d, before =%d", ng, numGoroutine) - } -} - -func TestTTHeaderStreamingRecvTimeout(t *testing.T) { - var addr = test.GetLocalAddress() - ln, _ := netpoll.CreateListener("tcp", addr) - defer ln.Close() - - // create server - svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) - // register streamingService as ttstreaam provider - sp, _ := ttstream.NewServerProvider(streamingServiceInfo) - _ = svr.RegisterService( - streamingServiceInfo, - new(streamingService), - streamxserver.WithProvider(sp), - ) - go func() { - _ = svr.Run() - }() - defer svr.Stop() - test.WaitServerStart(addr) - - cp, _ := ttstream.NewClientProvider( - streamingServiceInfo, - ttstream.WithClientLongConnPool( - ttstream.LongConnConfig{MaxIdleTimeout: time.Second}, - ), - ) - - // timeout by ctx itself - streamClient, _ := NewStreamingClient( - "kitex.service.streaming", - streamxclient.WithHostPorts(addr), - streamxclient.WithProvider(cp), - ) - ctx := context.Background() - bs, err := streamClient.BidiStream(ctx) - test.Assert(t, err == nil, err) - req := new(Request) - req.Message = string(make([]byte, 1024)) - err = bs.Send(ctx, req) - test.Assert(t, err == nil, err) - ctx, cancel := context.WithCancel(ctx) - cancel() - _, err = bs.Recv(ctx) - test.Assert(t, err != nil, err) - t.Logf("recv timeout error: %v", err) - err = bs.CloseSend(ctx) - test.Assert(t, err == nil, err) - - // timeout by client WithRecvTimeout - streamClient, _ = NewStreamingClient( - "kitex.service.streaming", - streamxclient.WithHostPorts(addr), - streamxclient.WithProvider(cp), - streamxclient.WithRecvTimeout(time.Nanosecond), - ) - ctx = context.Background() - bs, err = streamClient.BidiStream(ctx) - test.Assert(t, err == nil, err) - req = new(Request) - req.Message = string(make([]byte, 1024)) - err = bs.Send(ctx, req) - test.Assert(t, err == nil, err) - _, err = bs.Recv(ctx) - test.Assert(t, err != nil, err) - t.Logf("recv timeout error: %v", err) - err = bs.CloseSend(ctx) - test.Assert(t, err == nil, err) -} - -func TestTTHeaderStreamingServerGoroutines(t *testing.T) { - var addr = test.GetLocalAddress() - ln, _ := netpoll.CreateListener("tcp", addr) - defer ln.Close() - - // create server - svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) - // register streamingService as ttstreaam provider - sp, _ := ttstream.NewServerProvider(streamingServiceInfo) - _ = svr.RegisterService( - streamingServiceInfo, - new(streamingService), - streamxserver.WithProvider(sp), - ) - go func() { - _ = svr.Run() - }() - defer svr.Stop() - test.WaitServerStart(addr) - - cp, _ := ttstream.NewClientProvider( - streamingServiceInfo, - ttstream.WithClientLongConnPool(ttstream.LongConnConfig{MaxIdleTimeout: time.Second}), - ) - streamClient, _ := NewStreamingClient( - "kitex.service.streaming", - streamxclient.WithHostPorts(addr), - streamxclient.WithProvider(cp), - ) - - oldNGs := runtime.NumGoroutine() - streams := 100 - streamList := make([]streamx.ServerStream, streams) - for i := 0; i < streams; i++ { - ctx := context.Background() - bs, err := streamClient.BidiStream(ctx) - test.Assert(t, err == nil, err) - streamList[i] = bs - } - ngs := runtime.NumGoroutine() - test.Assert(t, ngs > streams, ngs) - for i := 0; i < streams; i++ { - streamList[i] = nil - } - streamList = nil - for ngs-oldNGs > 10 { - runtime.GC() - ngs = runtime.NumGoroutine() - time.Sleep(time.Millisecond * 100) - } -} - -func BenchmarkTTHeaderStreaming(b *testing.B) { - klog.SetLevel(klog.LevelWarn) - var addr = test.GetLocalAddress() - ln, _ := netpoll.CreateListener("tcp", addr) - defer ln.Close() - - // create server - svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) - // register streamingService as ttstreaam provider - sp, _ := ttstream.NewServerProvider(streamingServiceInfo) - _ = svr.RegisterService(streamingServiceInfo, new(streamingService), streamxserver.WithProvider(sp)) - go func() { - _ = svr.Run() - }() - defer svr.Stop() - test.WaitServerStart(addr) - - streamClient, _ := NewStreamingClient("kitex.service.streaming", streamxclient.WithHostPorts(addr)) - ctx := context.Background() - bs, err := streamClient.BidiStream(ctx) - if err != nil { - b.Fatal(err) - } - msg := "BidiStream" - var wg sync.WaitGroup - wg.Add(1) - b.ResetTimer() - b.ReportAllocs() - for i := 0; i < b.N; i++ { - req := new(Request) - req.Message = msg - err := bs.Send(ctx, req) - if err != nil { - b.Fatal(err) - } - res, err := bs.Recv(ctx) - if errors.Is(err, io.EOF) { - break - } - _ = res - } - err = bs.CloseSend(ctx) - if err != nil { - b.Fatal(err) - } -} diff --git a/pkg/streamx/provider/ttstream/ttstream_common_test.go b/pkg/streamx/provider/ttstream/ttstream_common_test.go deleted file mode 100644 index 242a75cd10..0000000000 --- a/pkg/streamx/provider/ttstream/ttstream_common_test.go +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ttstream_test - -import ( - "context" - - "github.com/bytedance/gopkg/cloud/metainfo" -) - -var persistKVs = map[string]string{ - "p1": "v1", - "p2": "v2", - "p3": "v3", -} - -var transitKVs = map[string]string{ - "t1": "v1", - "t2": "v2", - "t3": "v3", -} - -func SetMetadata(ctx context.Context) context.Context { - for k, v := range persistKVs { - ctx = metainfo.WithPersistentValue(ctx, k, v) - } - for k, v := range transitKVs { - ctx = metainfo.WithValue(ctx, k, v) - } - return ctx -} - -func ValidateMetadata(ctx context.Context) bool { - for k, v := range persistKVs { - _v, _ := metainfo.GetPersistentValue(ctx, k) - if _v != v { - return false - } - } - for k, v := range transitKVs { - _v, _ := metainfo.GetValue(ctx, k) - if _v != v { - return false - } - } - return true -} diff --git a/pkg/streamx/provider/ttstream/ttstream_error_test.go b/pkg/streamx/provider/ttstream/ttstream_error_test.go deleted file mode 100644 index 7153bd60c3..0000000000 --- a/pkg/streamx/provider/ttstream/ttstream_error_test.go +++ /dev/null @@ -1,174 +0,0 @@ -package ttstream_test - -import ( - "context" - "testing" - "time" - - "github.com/cloudwego/gopkg/protocol/thrift" - "github.com/cloudwego/netpoll" - - "github.com/cloudwego/kitex/client/streamxclient" - "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/kerrors" - "github.com/cloudwego/kitex/pkg/klog" - "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" - "github.com/cloudwego/kitex/server" - "github.com/cloudwego/kitex/server/streamxserver" -) - -const ( - normalErr int32 = iota + 1 - bizErr -) - -var ( - testCode = int32(10001) - testMsg = "biz testMsg" - testExtra = map[string]string{ - "testKey": "testVal", - } - normalErrMsg = "normal error" -) - -func assertNormalErr(t *testing.T, err error) { - ex, ok := err.(*thrift.ApplicationException) - test.Assert(t, ok, err) - test.Assert(t, ex.TypeID() == remote.InternalError, ex.TypeID()) - test.Assert(t, ex.Msg() == "biz error: "+normalErrMsg, ex.Msg()) -} - -func assertBizErr(t *testing.T, err error) { - bizIntf, ok := kerrors.FromBizStatusError(err) - test.Assert(t, ok) - test.Assert(t, bizIntf.BizStatusCode() == testCode, bizIntf.BizStatusCode()) - test.Assert(t, bizIntf.BizMessage() == testMsg, bizIntf.BizMessage()) - test.DeepEqual(t, bizIntf.BizExtra(), testExtra) -} - -func TestTTHeaderStreamingErrorHandling(t *testing.T) { - klog.SetLevel(klog.LevelDebug) - var addr = test.GetLocalAddress() - ln, err := netpoll.CreateListener("tcp", addr) - test.Assert(t, err == nil, err) - defer ln.Close() - - svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) - sp, err := ttstream.NewServerProvider(streamingServiceInfo) - test.Assert(t, err == nil, err) - err = svr.RegisterService( - streamingServiceInfo, - new(streamingService), - streamxserver.WithProvider(sp), - ) - test.Assert(t, err == nil, err) - go func() { - err := svr.Run() - test.Assert(t, err == nil, err) - }() - defer svr.Stop() - test.WaitServerStart(addr) - - streamClient, err := NewStreamingClient( - "kitex.service.streaming", - streamxclient.WithHostPorts(addr), - ) - test.Assert(t, err == nil, err) - - t.Logf("=== UnaryWithErr normalErr ===") - req := new(Request) - req.Type = normalErr - res, err := streamClient.UnaryWithErr(context.Background(), req) - test.Assert(t, res == nil, res) - test.Assert(t, err != nil, err) - assertNormalErr(t, err) - - t.Logf("=== UnaryWithErr bizErr ===") - req = new(Request) - req.Type = bizErr - res, err = streamClient.UnaryWithErr(context.Background(), req) - test.Assert(t, res == nil, res) - test.Assert(t, err != nil, err) - assertBizErr(t, err) - - t.Logf("=== ClientStreamWithErr normalErr ===") - ctx := context.Background() - cliStream, err := streamClient.ClientStreamWithErr(ctx) - test.Assert(t, err == nil, err) - test.Assert(t, cliStream != nil, cliStream) - req = new(Request) - req.Type = normalErr - err = cliStream.Send(ctx, req) - test.Assert(t, err == nil, err) - res, err = cliStream.CloseAndRecv(ctx) - test.Assert(t, res == nil, res) - test.Assert(t, err != nil, err) - assertNormalErr(t, err) - - t.Logf("=== ClientStreamWithErr bizErr ===") - ctx = context.Background() - cliStream, err = streamClient.ClientStreamWithErr(ctx) - test.Assert(t, err == nil, err) - test.Assert(t, cliStream != nil, cliStream) - req = new(Request) - req.Type = bizErr - err = cliStream.Send(ctx, req) - test.Assert(t, err == nil, err) - res, err = cliStream.CloseAndRecv(ctx) - test.Assert(t, res == nil, res) - test.Assert(t, err != nil, err) - assertBizErr(t, err) - - t.Logf("=== ServerStreamWithErr normalErr ===") - ctx = context.Background() - req = new(Request) - req.Type = normalErr - svrStream, err := streamClient.ServerStreamWithErr(ctx, req) - test.Assert(t, err == nil, err) - test.Assert(t, svrStream != nil, svrStream) - res, err = svrStream.Recv(ctx) - test.Assert(t, res == nil, res) - test.Assert(t, err != nil, err) - assertNormalErr(t, err) - - t.Logf("=== ServerStreamWithErr bizErr ===") - ctx = context.Background() - req = new(Request) - req.Type = bizErr - svrStream, err = streamClient.ServerStreamWithErr(ctx, req) - test.Assert(t, err == nil, err) - test.Assert(t, svrStream != nil, svrStream) - res, err = svrStream.Recv(ctx) - test.Assert(t, res == nil, res) - test.Assert(t, err != nil, err) - assertBizErr(t, err) - - t.Logf("=== BidiStreamWithErr normalErr ===") - ctx = context.Background() - bidiStream, err := streamClient.BidiStreamWithErr(ctx) - test.Assert(t, err == nil, err) - test.Assert(t, bidiStream != nil, bidiStream) - req = new(Request) - req.Type = normalErr - err = bidiStream.Send(ctx, req) - test.Assert(t, err == nil, err) - res, err = bidiStream.Recv(ctx) - test.Assert(t, res == nil, res) - test.Assert(t, err != nil, err) - assertNormalErr(t, err) - - t.Logf("=== BidiStreamWithErr bizErr ===") - ctx = context.Background() - bidiStream, err = streamClient.BidiStreamWithErr(ctx) - test.Assert(t, err == nil, err) - test.Assert(t, bidiStream != nil, bidiStream) - req = new(Request) - req.Type = bizErr - err = bidiStream.Send(ctx, req) - test.Assert(t, err == nil, err) - res, err = bidiStream.Recv(ctx) - test.Assert(t, res == nil, res) - test.Assert(t, err != nil, err) - assertBizErr(t, err) -} diff --git a/pkg/streamx/stream.go b/pkg/streamx/stream.go index 42a7b043cd..e478a65024 100644 --- a/pkg/streamx/stream.go +++ b/pkg/streamx/stream.go @@ -22,12 +22,14 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" ) -var _ ServerStreamingClient[int] = (*GenericClientStream[int, int])(nil) -var _ ClientStreamingClient[int, int] = (*GenericClientStream[int, int])(nil) -var _ BidiStreamingClient[int, int] = (*GenericClientStream[int, int])(nil) -var _ ServerStreamingServer[int] = (*GenericServerStream[int, int])(nil) -var _ ClientStreamingServer[int, int] = (*GenericServerStream[int, int])(nil) -var _ BidiStreamingServer[int, int] = (*GenericServerStream[int, int])(nil) +var ( + _ ServerStreamingClient[int] = (*GenericClientStream[int, int])(nil) + _ ClientStreamingClient[int, int] = (*GenericClientStream[int, int])(nil) + _ BidiStreamingClient[int, int] = (*GenericClientStream[int, int])(nil) + _ ServerStreamingServer[int] = (*GenericServerStream[int, int])(nil) + _ ClientStreamingServer[int, int] = (*GenericServerStream[int, int])(nil) + _ BidiStreamingServer[int, int] = (*GenericServerStream[int, int])(nil) +) type StreamingMode = serviceinfo.StreamingMode @@ -136,7 +138,7 @@ type ClientStreamingClient[Req, Res any] interface { type ClientStreamingServer[Req, Res any] interface { Recv(ctx context.Context) (*Req, error) - //SendAndClose(ctx context.Context, res *Res) error + // SendAndClose(ctx context.Context, res *Res) error ServerStream ServerStreamMetadata } diff --git a/pkg/streamx/stream_middleware.go b/pkg/streamx/stream_middleware.go index ed2f8e948b..bce62d65a7 100644 --- a/pkg/streamx/stream_middleware.go +++ b/pkg/streamx/stream_middleware.go @@ -27,14 +27,20 @@ type StreamHandler struct { StreamSendMiddleware StreamSendMiddleware } -type StreamEndpoint func(ctx context.Context, streamArgs StreamArgs, reqArgs StreamReqArgs, resArgs StreamResArgs) (err error) -type StreamMiddleware func(next StreamEndpoint) StreamEndpoint +type ( + StreamEndpoint func(ctx context.Context, streamArgs StreamArgs, reqArgs StreamReqArgs, resArgs StreamResArgs) (err error) + StreamMiddleware func(next StreamEndpoint) StreamEndpoint +) -type StreamRecvEndpoint func(ctx context.Context, stream Stream, res any) (err error) -type StreamSendEndpoint func(ctx context.Context, stream Stream, req any) (err error) +type ( + StreamRecvEndpoint func(ctx context.Context, stream Stream, res any) (err error) + StreamSendEndpoint func(ctx context.Context, stream Stream, req any) (err error) +) -type StreamRecvMiddleware func(next StreamRecvEndpoint) StreamRecvEndpoint -type StreamSendMiddleware func(next StreamSendEndpoint) StreamSendEndpoint +type ( + StreamRecvMiddleware func(next StreamRecvEndpoint) StreamRecvEndpoint + StreamSendMiddleware func(next StreamSendEndpoint) StreamSendEndpoint +) func StreamMiddlewareChain(mws ...StreamMiddleware) StreamMiddleware { return func(next StreamEndpoint) StreamEndpoint { diff --git a/pkg/streamx/stream_middleware_internal.go b/pkg/streamx/stream_middleware_internal.go index ab2d9f35f0..d06d7d491d 100644 --- a/pkg/streamx/stream_middleware_internal.go +++ b/pkg/streamx/stream_middleware_internal.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx import ( diff --git a/pkg/streamx/streamx_common_test.go b/pkg/streamx/streamx_common_test.go new file mode 100644 index 0000000000..7a154076d4 --- /dev/null +++ b/pkg/streamx/streamx_common_test.go @@ -0,0 +1,95 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package streamx_test + +import ( + "context" + "testing" + + "github.com/bytedance/gopkg/cloud/metainfo" + "github.com/cloudwego/gopkg/protocol/thrift" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote" +) + +const ( + normalErr int32 = iota + 1 + bizErr +) + +var ( + testCode = int32(10001) + testMsg = "biz testMsg" + testExtra = map[string]string{ + "testKey": "testVal", + } + normalErrMsg = "normal error" + + persistKVs = map[string]string{ + "p1": "v1", + "p2": "v2", + "p3": "v3", + } + transitKVs = map[string]string{ + "t1": "v1", + "t2": "v2", + "t3": "v3", + } +) + +func setMetadata(ctx context.Context) context.Context { + for k, v := range persistKVs { + ctx = metainfo.WithPersistentValue(ctx, k, v) + } + for k, v := range transitKVs { + ctx = metainfo.WithValue(ctx, k, v) + } + return ctx +} + +func validateMetadata(ctx context.Context) bool { + for k, v := range persistKVs { + _v, _ := metainfo.GetPersistentValue(ctx, k) + if _v != v { + return false + } + } + for k, v := range transitKVs { + _v, _ := metainfo.GetValue(ctx, k) + if _v != v { + return false + } + } + return true +} + +func assertNormalErr(t *testing.T, err error) { + ex, ok := err.(*thrift.ApplicationException) + test.Assert(t, ok, err) + test.Assert(t, ex.TypeID() == remote.InternalError, ex.TypeID()) + test.Assert(t, ex.Msg() == "biz error: "+normalErrMsg, ex.Msg()) +} + +func assertBizErr(t *testing.T, err error) { + bizIntf, ok := kerrors.FromBizStatusError(err) + test.Assert(t, ok) + test.Assert(t, bizIntf.BizStatusCode() == testCode, bizIntf.BizStatusCode()) + test.Assert(t, bizIntf.BizMessage() == testMsg, bizIntf.BizMessage()) + test.DeepEqual(t, bizIntf.BizExtra(), testExtra) +} diff --git a/pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go b/pkg/streamx/streamx_gen_codec_test.go similarity index 99% rename from pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go rename to pkg/streamx/streamx_gen_codec_test.go index 2784566ae9..67793454aa 100644 --- a/pkg/streamx/provider/ttstream/ttstream_gen_codec_test.go +++ b/pkg/streamx/streamx_gen_codec_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package ttstream_test +package streamx_test import ( "bytes" @@ -22,9 +22,9 @@ import ( "reflect" "strings" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - "github.com/apache/thrift/lib/go/thrift" + + "github.com/cloudwego/kitex/pkg/protocol/bthrift" kutils "github.com/cloudwego/kitex/pkg/utils" ) diff --git a/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go b/pkg/streamx/streamx_gen_service_test.go similarity index 84% rename from pkg/streamx/provider/ttstream/ttstream_gen_service_test.go rename to pkg/streamx/streamx_gen_service_test.go index 59062f789f..ea774a2cec 100644 --- a/pkg/streamx/provider/ttstream/ttstream_gen_service_test.go +++ b/pkg/streamx/streamx_gen_service_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package ttstream_test +package streamx_test import ( "context" @@ -31,15 +31,6 @@ import ( // === gen code === -// --- Define Header and Trailer type --- -type ClientStreamingServer[Req, Res any] streamx.ClientStreamingServer[Req, Res] -type ServerStreamingServer[Res any] streamx.ServerStreamingServer[Res] -type BidiStreamingServer[Req, Res any] streamx.BidiStreamingServer[Req, Res] - -type ClientStreamingClient[Req, Res any] streamx.ClientStreamingClient[Req, Res] -type ServerStreamingClient[Res any] streamx.ServerStreamingClient[Res] -type BidiStreamingClient[Req, Res any] streamx.BidiStreamingClient[Req, Res] - // --- Define Service Method handler --- var pingpongServiceInfo = &serviceinfo.ServiceInfo{ ServiceName: "kitex.service.pingpong", @@ -194,13 +185,13 @@ type PingPongServerInterface interface { } type StreamingServerInterface interface { Unary(ctx context.Context, req *Request) (*Response, error) - ClientStream(ctx context.Context, stream ClientStreamingServer[Request, Response]) (*Response, error) - ServerStream(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error - BidiStream(ctx context.Context, stream BidiStreamingServer[Request, Response]) error + ClientStream(ctx context.Context, stream streamx.ClientStreamingServer[Request, Response]) (*Response, error) + ServerStream(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error + BidiStream(ctx context.Context, stream streamx.BidiStreamingServer[Request, Response]) error UnaryWithErr(ctx context.Context, req *Request) (*Response, error) - ClientStreamWithErr(ctx context.Context, stream ClientStreamingServer[Request, Response]) (*Response, error) - ServerStreamWithErr(ctx context.Context, req *Request, stream ServerStreamingServer[Response]) error - BidiStreamWithErr(ctx context.Context, stream BidiStreamingServer[Request, Response]) error + ClientStreamWithErr(ctx context.Context, stream streamx.ClientStreamingServer[Request, Response]) (*Response, error) + ServerStreamWithErr(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error + BidiStreamWithErr(ctx context.Context, stream streamx.BidiStreamingServer[Request, Response]) error } // --- Define Client Implementation Interface --- @@ -211,23 +202,25 @@ type PingPongClientInterface interface { type StreamingClientInterface interface { Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream ClientStreamingClient[Request, Response], err error) + stream streamx.ClientStreamingClient[Request, Response], err error) ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( - stream ServerStreamingClient[Response], err error) + stream streamx.ServerStreamingClient[Response], err error) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream BidiStreamingClient[Request, Response], err error) + stream streamx.BidiStreamingClient[Request, Response], err error) UnaryWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error) ClientStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream ClientStreamingClient[Request, Response], err error) + stream streamx.ClientStreamingClient[Request, Response], err error) ServerStreamWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( - stream ServerStreamingClient[Response], err error) + stream streamx.ServerStreamingClient[Response], err error) BidiStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream BidiStreamingClient[Request, Response], err error) + stream streamx.BidiStreamingClient[Request, Response], err error) } // --- Define Client Implementation --- -var _ StreamingClientInterface = (*kClient)(nil) -var _ PingPongClientInterface = (*kClient)(nil) +var ( + _ StreamingClientInterface = (*kClient)(nil) + _ PingPongClientInterface = (*kClient)(nil) +) type kClient struct { caller client.Client @@ -254,19 +247,23 @@ func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...stream return res, nil } -func (c *kClient) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream ClientStreamingClient[Request, Response], err error) { +func (c *kClient) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream streamx.ClientStreamingClient[Request, Response], err error, +) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingClient, "ClientStream", nil, nil, callOptions...) } func (c *kClient) ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( - stream ServerStreamingClient[Response], err error) { + stream streamx.ServerStreamingClient[Response], err error, +) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingServer, "ServerStream", req, nil, callOptions...) } func (c *kClient) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream BidiStreamingClient[Request, Response], err error) { + stream streamx.BidiStreamingClient[Request, Response], err error, +) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingBidirectional, "BidiStream", nil, nil, callOptions...) } @@ -281,19 +278,23 @@ func (c *kClient) UnaryWithErr(ctx context.Context, req *Request, callOptions .. return res, nil } -func (c *kClient) ClientStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream ClientStreamingClient[Request, Response], err error) { +func (c *kClient) ClientStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + stream streamx.ClientStreamingClient[Request, Response], err error, +) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingClient, "ClientStreamWithErr", nil, nil, callOptions...) } func (c *kClient) ServerStreamWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( - stream ServerStreamingClient[Response], err error) { + stream streamx.ServerStreamingClient[Response], err error, +) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingServer, "ServerStreamWithErr", req, nil, callOptions...) } func (c *kClient) BidiStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream BidiStreamingClient[Request, Response], err error) { + stream streamx.BidiStreamingClient[Request, Response], err error, +) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingBidirectional, "BidiStreamWithErr", nil, nil, callOptions...) } diff --git a/pkg/streamx/provider/ttstream/ttstream_server_test.go b/pkg/streamx/streamx_user_service_test.go similarity index 88% rename from pkg/streamx/provider/ttstream/ttstream_server_test.go rename to pkg/streamx/streamx_user_service_test.go index a0366af090..f701a31dc6 100644 --- a/pkg/streamx/provider/ttstream/ttstream_server_test.go +++ b/pkg/streamx/streamx_user_service_test.go @@ -14,21 +14,25 @@ * limitations under the License. */ -package ttstream_test +package streamx_test import ( "context" "errors" "io" + "testing" + "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/ktx" ) -type pingpongService struct{} -type streamingService struct{} +type ( + pingpongService struct{} + streamingService struct{} +) const ( headerKey = "header1" @@ -37,6 +41,15 @@ const ( trailerVal = "value1" ) +func testHeaderAndTrailer(t *testing.T, stream streamx.ClientStreamMetadata) { + hd, err := stream.Header() + test.Assert(t, err == nil, err) + test.Assert(t, hd[headerKey] == headerVal, hd) + tl, err := stream.Trailer() + test.Assert(t, err == nil, err) + test.Assert(t, tl[trailerKey] == trailerVal, tl) +} + func (si *streamingService) setHeaderAndTrailer(stream streamx.ServerStreamMetadata) error { err := stream.SetTrailer(streamx.Trailer{trailerKey: trailerVal}) if err != nil { @@ -63,7 +76,8 @@ func (si *streamingService) Unary(ctx context.Context, req *Request) (*Response, } func (si *streamingService) ClientStream(ctx context.Context, - stream streamx.ClientStreamingServer[Request, Response]) (*Response, error) { + stream streamx.ClientStreamingServer[Request, Response], +) (*Response, error) { var msg string klog.Infof("Server ClientStream start") defer klog.Infof("Server ClientStream end") @@ -87,7 +101,8 @@ func (si *streamingService) ClientStream(ctx context.Context, } func (si *streamingService) ServerStream(ctx context.Context, req *Request, - stream streamx.ServerStreamingServer[Response]) error { + stream streamx.ServerStreamingServer[Response], +) error { klog.Infof("Server ServerStream: req={%v}", req) if err := si.setHeaderAndTrailer(stream); err != nil { @@ -108,7 +123,8 @@ func (si *streamingService) ServerStream(ctx context.Context, req *Request, } func (si *streamingService) BidiStream(ctx context.Context, - stream streamx.BidiStreamingServer[Request, Response]) error { + stream streamx.BidiStreamingServer[Request, Response], +) error { ktx.RegisterCancelCallback(ctx, func() { klog.Debugf("RegisterCancelCallback work!") }) diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go new file mode 100644 index 0000000000..8d5a1b92ef --- /dev/null +++ b/pkg/streamx/streamx_user_test.go @@ -0,0 +1,596 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package streamx_test + +import ( + "context" + "errors" + "io" + "log" + "net/http" + _ "net/http/pprof" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/client/streamxclient" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" + "github.com/cloudwego/kitex/server" + "github.com/cloudwego/kitex/server/streamxserver" + "github.com/cloudwego/kitex/transport" +) + +var providerTestCases []testCase + +type testCase struct { + Name string + ClientProvider streamx.ClientProvider + ServerProvider streamx.ServerProvider +} + +func init() { + klog.SetLevel(klog.LevelWarn) + + sp, _ := ttstream.NewServerProvider(streamingServiceInfo) + cp, _ := ttstream.NewClientProvider(streamingServiceInfo, ttstream.WithClientLongConnPool(ttstream.LongConnConfig{MaxIdleTimeout: time.Millisecond * 100})) + providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_LongConn", ClientProvider: cp, ServerProvider: sp}) + cp, _ = ttstream.NewClientProvider(streamingServiceInfo, ttstream.WithClientShortConnPool()) + providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_ShortConn", ClientProvider: cp, ServerProvider: sp}) + cp, _ = ttstream.NewClientProvider(streamingServiceInfo, ttstream.WithClientMuxConnPool()) + providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_Mux", ClientProvider: cp, ServerProvider: sp}) +} + +func TestMain(m *testing.M) { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + m.Run() +} + +func TestStreamingBasic(t *testing.T) { + for _, tc := range providerTestCases { + t.Run(tc.Name, func(t *testing.T) { + // === prepare test environment === + addr := test.GetLocalAddress() + ln, err := netpoll.CreateListener("tcp", addr) + test.Assert(t, err == nil, err) + defer ln.Close() + // create server + var serverStreamCount int32 + waitServerStreamDone := func() { + for atomic.LoadInt32(&serverStreamCount) != 0 { + t.Logf("waitServerStreamDone: %d", atomic.LoadInt32(&serverStreamCount)) + time.Sleep(time.Millisecond * 10) + } + } + var serverRecvCount int32 + var serverSendCount int32 + svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) + // register pingpong service + err = svr.RegisterService(pingpongServiceInfo, new(pingpongService)) + test.Assert(t, err == nil, err) + // register streamingService as ttstreaam provider + err = svr.RegisterService( + streamingServiceInfo, + new(streamingService), + streamxserver.WithProvider(tc.ServerProvider), + streamxserver.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { + return func(ctx context.Context, stream streamx.Stream, res any) (err error) { + err = next(ctx, stream, res) + if err == nil { + atomic.AddInt32(&serverRecvCount, 1) + } + return err + } + }), + streamxserver.WithStreamSendMiddleware(func(next streamx.StreamSendEndpoint) streamx.StreamSendEndpoint { + return func(ctx context.Context, stream streamx.Stream, req any) (err error) { + err = next(ctx, stream, req) + if err == nil { + atomic.AddInt32(&serverSendCount, 1) + } + return err + } + }), + streamxserver.WithStreamMiddleware( + // middleware example: server streaming mode + func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + log.Printf("Server middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", + reqArgs.Req(), resArgs.Res(), streamArgs) + test.Assert(t, streamArgs.Stream() != nil) + test.Assert(t, validateMetadata(ctx)) + + switch streamArgs.Stream().Mode() { + case streamx.StreamingUnary: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() != nil || err != nil) + case streamx.StreamingClient: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() != nil || err != nil) + case streamx.StreamingServer: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingBidirectional: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + err = next(ctx, streamArgs, reqArgs, resArgs) + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + } + + log.Printf("Server middleware after next: reqArgs=%v resArgs=%v streamArgs=%v err=%v", + reqArgs.Req(), resArgs.Res(), streamArgs.Stream(), err) + atomic.AddInt32(&serverStreamCount, 1) + return err + } + }, + ), + ) + test.Assert(t, err == nil, err) + go func() { + err := svr.Run() + test.Assert(t, err == nil, err) + }() + defer svr.Stop() + test.WaitServerStart(addr) + + // create client + pingpongClient, err := NewPingPongClient( + "kitex.service.pingpong", + client.WithHostPorts(addr), + client.WithTransportProtocol(transport.TTHeaderFramed), + client.WithPayloadCodec(thrift.NewThriftCodecWithConfig(thrift.FastRead|thrift.FastWrite|thrift.EnableSkipDecoder)), + ) + test.Assert(t, err == nil, err) + // create streaming client + streamClient, err := NewStreamingClient( + "kitex.service.streaming", + streamxclient.WithProvider(tc.ClientProvider), + streamxclient.WithHostPorts(addr), + streamxclient.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { + return func(ctx context.Context, stream streamx.Stream, res any) (err error) { + err = next(ctx, stream, res) + return err + } + }), + streamxclient.WithStreamSendMiddleware(func(next streamx.StreamSendEndpoint) streamx.StreamSendEndpoint { + return func(ctx context.Context, stream streamx.Stream, req any) (err error) { + err = next(ctx, stream, req) + return err + } + }), + streamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + // validate ctx + test.Assert(t, validateMetadata(ctx)) + + err = next(ctx, streamArgs, reqArgs, resArgs) + + test.Assert(t, streamArgs.Stream() != nil) + switch streamArgs.Stream().Mode() { + case streamx.StreamingUnary: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() != nil || err != nil) + case streamx.StreamingClient: + test.Assert(t, reqArgs.Req() == nil, reqArgs.Req()) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingServer: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingBidirectional: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + } + return err + } + }), + ) + test.Assert(t, err == nil, err) + + // prepare metainfo + ctx := context.Background() + ctx = setMetadata(ctx) + + t.Logf("=== PingPong ===") + req := new(Request) + req.Message = "PingPong" + res, err := pingpongClient.PingPong(ctx, req) + test.Assert(t, err == nil, err) + test.Assert(t, req.Message == res.Message, res) + + t.Logf("=== Unary ===") + req = new(Request) + req.Type = 10000 + req.Message = "Unary" + res, err = streamClient.Unary(ctx, req) + test.Assert(t, err == nil, err) + test.Assert(t, req.Type == res.Type, res.Type) + test.Assert(t, req.Message == res.Message, res.Message) + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() + test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(1)) + test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(1)) + atomic.StoreInt32(&serverRecvCount, 0) + atomic.StoreInt32(&serverSendCount, 0) + + // client stream + round := 5 + t.Logf("=== ClientStream ===") + cs, err := streamClient.ClientStream(ctx) + test.Assert(t, err == nil, err) + for i := 0; i < round; i++ { + req := new(Request) + req.Type = int32(i) + req.Message = "ClientStream" + err = cs.Send(ctx, req) + test.Assert(t, err == nil, err) + } + res, err = cs.CloseAndRecv(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, res.Message == "ClientStream", res.Message) + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() + testHeaderAndTrailer(t, cs) + test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(round)) + test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(1)) + atomic.StoreInt32(&serverRecvCount, 0) + atomic.StoreInt32(&serverSendCount, 0) + cs = nil + runtime.GC() + + // server stream + t.Logf("=== ServerStream ===") + req = new(Request) + req.Message = "ServerStream" + ss, err := streamClient.ServerStream(ctx, req) + test.Assert(t, err == nil, err) + received := 0 + for { + res, err := ss.Recv(ctx) + if errors.Is(err, io.EOF) { + break + } + test.Assert(t, err == nil, err) + received++ + t.Logf("Client ServerStream recv: %v", res) + } + err = ss.CloseSend(ctx) + test.Assert(t, err == nil, err) + atomic.AddInt32(&serverStreamCount, -1) + waitServerStreamDone() + testHeaderAndTrailer(t, ss) + test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(1)) + test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(received)) + atomic.StoreInt32(&serverRecvCount, 0) + atomic.StoreInt32(&serverSendCount, 0) + ss = nil + runtime.GC() + + // bidi stream + t.Logf("=== BidiStream ===") + concurrent := 32 + round = 5 + for c := 0; c < concurrent; c++ { + atomic.AddInt32(&serverStreamCount, -1) + go func() { + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + msg := "BidiStream" + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i := 0; i < round; i++ { + req := new(Request) + req.Message = msg + err := bs.Send(ctx, req) + test.Assert(t, err == nil, err) + } + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + }() + go func() { + defer wg.Done() + i := 0 + for { + res, err := bs.Recv(ctx) + if errors.Is(err, io.EOF) { + break + } + i++ + test.Assert(t, err == nil, err) + test.Assert(t, msg == res.Message, res.Message) + } + test.Assert(t, i == round, i) + }() + testHeaderAndTrailer(t, bs) + }() + } + waitServerStreamDone() + test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(concurrent*round)) + test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(concurrent*round)) + atomic.StoreInt32(&serverRecvCount, 0) + atomic.StoreInt32(&serverSendCount, 0) + runtime.GC() + + t.Logf("=== UnaryWithErr normalErr ===") + req = new(Request) + req.Type = normalErr + res, err = streamClient.UnaryWithErr(ctx, req) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertNormalErr(t, err) + + t.Logf("=== UnaryWithErr bizErr ===") + req = new(Request) + req.Type = bizErr + res, err = streamClient.UnaryWithErr(ctx, req) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertBizErr(t, err) + + t.Logf("=== ClientStreamWithErr normalErr ===") + cliStream, err := streamClient.ClientStreamWithErr(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, cliStream != nil, cliStream) + req = new(Request) + req.Type = normalErr + err = cliStream.Send(ctx, req) + test.Assert(t, err == nil, err) + res, err = cliStream.CloseAndRecv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertNormalErr(t, err) + + t.Logf("=== ClientStreamWithErr bizErr ===") + cliStream, err = streamClient.ClientStreamWithErr(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, cliStream != nil, cliStream) + req = new(Request) + req.Type = bizErr + err = cliStream.Send(ctx, req) + test.Assert(t, err == nil, err) + res, err = cliStream.CloseAndRecv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertBizErr(t, err) + + t.Logf("=== ServerStreamWithErr normalErr ===") + req = new(Request) + req.Type = normalErr + svrStream, err := streamClient.ServerStreamWithErr(ctx, req) + test.Assert(t, err == nil, err) + test.Assert(t, svrStream != nil, svrStream) + res, err = svrStream.Recv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertNormalErr(t, err) + + t.Logf("=== ServerStreamWithErr bizErr ===") + req = new(Request) + req.Type = bizErr + svrStream, err = streamClient.ServerStreamWithErr(ctx, req) + test.Assert(t, err == nil, err) + test.Assert(t, svrStream != nil, svrStream) + res, err = svrStream.Recv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertBizErr(t, err) + + t.Logf("=== BidiStreamWithErr normalErr ===") + bidiStream, err := streamClient.BidiStreamWithErr(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, bidiStream != nil, bidiStream) + req = new(Request) + req.Type = normalErr + err = bidiStream.Send(ctx, req) + test.Assert(t, err == nil, err) + res, err = bidiStream.Recv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertNormalErr(t, err) + + t.Logf("=== BidiStreamWithErr bizErr ===") + bidiStream, err = streamClient.BidiStreamWithErr(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, bidiStream != nil, bidiStream) + req = new(Request) + req.Type = bizErr + err = bidiStream.Send(ctx, req) + test.Assert(t, err == nil, err) + res, err = bidiStream.Recv(ctx) + test.Assert(t, res == nil, res) + test.Assert(t, err != nil, err) + assertBizErr(t, err) + + t.Logf("=== Timeout by Ctx ===") + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + req = new(Request) + req.Message = string(make([]byte, 1024)) + err = bs.Send(ctx, req) + test.Assert(t, err == nil, err) + nctx, cancel := context.WithCancel(ctx) + cancel() + _, err = bs.Recv(nctx) + test.Assert(t, err != nil, err) + t.Logf("recv timeout error: %v", err) + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + + // timeout by client WithRecvTimeout + t.Logf("=== Timeout by WithRecvTimeout ===") + streamClient, _ = NewStreamingClient( + "kitex.service.streaming", + streamxclient.WithHostPorts(addr), + streamxclient.WithProvider(tc.ClientProvider), + streamxclient.WithRecvTimeout(time.Nanosecond), + ) + bs, err = streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + req = new(Request) + req.Message = string(make([]byte, 1024)) + err = bs.Send(ctx, req) + test.Assert(t, err == nil, err) + _, err = bs.Recv(ctx) + test.Assert(t, err != nil, err) + t.Logf("recv timeout error: %v", err) + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + + streamClient = nil + }) + } +} + +func TestStreamingGoroutineLeak(t *testing.T) { + for _, tc := range providerTestCases { + t.Run(tc.Name, func(t *testing.T) { + addr := test.GetLocalAddress() + ln, _ := netpoll.CreateListener("tcp", addr) + defer ln.Close() + + // create server + svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) + var streamStarted int32 + waitStreamStarted := func(streamWaited int) { + for { + stated, waited := atomic.LoadInt32(&streamStarted), int32(streamWaited) + if stated >= waited { + return + } + t.Logf("streamStarted=%d < streamWaited=%d", stated, waited) + time.Sleep(time.Millisecond * 10) + } + } + _ = svr.RegisterService( + streamingServiceInfo, new(streamingService), + streamxserver.WithProvider(tc.ServerProvider), + streamxserver.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + atomic.AddInt32(&streamStarted, 1) + return next(ctx, streamArgs, reqArgs, resArgs) + } + }), + ) + go func() { + _ = svr.Run() + }() + defer svr.Stop() + test.WaitServerStart(addr) + + streamClient, _ := NewStreamingClient( + "kitex.service.streaming", + streamxclient.WithHostPorts(addr), + streamxclient.WithProvider(tc.ClientProvider), + ) + ctx := context.Background() + msg := "BidiStream" + + t.Logf("=== Checking only one connection be reused ===") + var wg sync.WaitGroup + for i := 0; i < 12; i++ { + wg.Add(1) + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + req := new(Request) + req.Message = string(make([]byte, 1024)) + err = bs.Send(ctx, req) + test.Assert(t, err == nil, err) + res, err := bs.Recv(ctx) + test.Assert(t, err == nil, err) + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, res.Message == req.Message, res.Message) + runtime.SetFinalizer(bs, func(_ any) { + wg.Done() + }) + bs = nil + runtime.GC() + wg.Wait() + } + + t.Logf("=== Checking streams GCed ===") + streams := 100 + streamList := make([]streamx.ServerStream, streams) + atomic.StoreInt32(&streamStarted, 0) + for i := 0; i < streams; i++ { + ctx := context.Background() + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + streamList[i] = bs + } + waitStreamStarted(streams) + // before GC + ngBefore := runtime.NumGoroutine() + test.Assert(t, runtime.NumGoroutine() > streams, runtime.NumGoroutine()) + // after GC + for i := 0; i < streams; i++ { + streamList[i] = nil + } + for runtime.NumGoroutine() > ngBefore { + t.Logf("ngCurrent=%d > ngBefore=%d", runtime.NumGoroutine(), ngBefore) + runtime.GC() + time.Sleep(time.Millisecond * 50) + } + + t.Logf("=== Checking Streams Called and GCed ===") + streams = 100 + for i := 0; i < streams; i++ { + wg.Add(1) + go func() { + bs, err := streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + req := new(Request) + req.Message = msg + err = bs.Send(ctx, req) + test.Assert(t, err == nil, err) + go func() { + defer wg.Done() + res, err := bs.Recv(ctx) + test.Assert(t, err == nil, err) + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, res.Message == msg, res.Message) + + testHeaderAndTrailer(t, bs) + }() + }() + } + wg.Wait() + }) + } +} diff --git a/server/streamxserver/server.go b/server/streamxserver/server.go index 890040e190..adfeb73432 100644 --- a/server/streamxserver/server.go +++ b/server/streamxserver/server.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamxserver import ( diff --git a/server/streamxserver/server_gen.go b/server/streamxserver/server_gen.go index f9ec27aad2..c92c8692b9 100644 --- a/server/streamxserver/server_gen.go +++ b/server/streamxserver/server_gen.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamxserver import ( @@ -14,7 +30,8 @@ var invokerCache sync.Map func InvokeStream[Req, Res any]( ctx context.Context, smode serviceinfo.StreamingMode, - handler any, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + handler any, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs, +) (err error) { // prepare args sArgs := streamx.GetStreamArgsFromContext(ctx) if sArgs == nil { diff --git a/server/streamxserver/server_option.go b/server/streamxserver/server_option.go index f9233703db..6869dcc8b5 100644 --- a/server/streamxserver/server_option.go +++ b/server/streamxserver/server_option.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamxserver import ( @@ -8,8 +24,10 @@ import ( "github.com/cloudwego/kitex/server" ) -type Option internal_server.Option -type Options = internal_server.Options +type ( + Option internal_server.Option + Options = internal_server.Options +) func WithListener(ln net.Listener) Option { return ConvertNativeServerOption(server.WithListener(ln)) diff --git a/tool/internal_pkg/tpl/streamx/client.go b/tool/internal_pkg/tpl/streamx/client.go index f88fce814e..2a5e88ca2b 100644 --- a/tool/internal_pkg/tpl/streamx/client.go +++ b/tool/internal_pkg/tpl/streamx/client.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx var ClientTpl = `// Code generated by Kitex {{.Version}}. DO NOT EDIT. diff --git a/tool/internal_pkg/tpl/streamx/handler.method.go b/tool/internal_pkg/tpl/streamx/handler.method.go index 0224237e3d..acbbdd7571 100644 --- a/tool/internal_pkg/tpl/streamx/handler.method.go +++ b/tool/internal_pkg/tpl/streamx/handler.method.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx var HandlerMethodsTpl = `{{define "HandlerMethod"}} diff --git a/tool/internal_pkg/tpl/streamx/server.go b/tool/internal_pkg/tpl/streamx/server.go index 217c7b3fa1..08eaa64483 100644 --- a/tool/internal_pkg/tpl/streamx/server.go +++ b/tool/internal_pkg/tpl/streamx/server.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx var ServerTpl = `// Code generated by Kitex {{.Version}}. DO NOT EDIT. diff --git a/tool/internal_pkg/tpl/streamx/service.go b/tool/internal_pkg/tpl/streamx/service.go index cbd727d5d0..d9bad47462 100644 --- a/tool/internal_pkg/tpl/streamx/service.go +++ b/tool/internal_pkg/tpl/streamx/service.go @@ -1,3 +1,19 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package streamx var ServiceTpl = `// Code generated by Kitex {{.Version}}. DO NOT EDIT. From 63b5605faae71b6924534d42ef06e7647064b915 Mon Sep 17 00:00:00 2001 From: Joway Date: Wed, 23 Oct 2024 15:04:37 +0800 Subject: [PATCH 10/34] feat: support client lifecycle latency (#1582) --- client/client_streamx.go | 18 ++- client/streamxclient/client_gen.go | 12 +- .../streamxcallopt/call_option.go | 42 +++++-- pkg/remote/trans/streamx/server_handler.go | 17 ++- pkg/streamx/client_provider.go | 3 +- pkg/streamx/client_provider_internal.go | 5 +- .../provider/jsonrpc/client_provier.go | 3 +- .../provider/ttstream/client_provier.go | 31 ++--- .../provider/ttstream/server_provider.go | 5 + pkg/streamx/provider/ttstream/stream.go | 4 +- pkg/streamx/provider/ttstream/stream_io.go | 42 ++++--- pkg/streamx/provider/ttstream/transport.go | 67 ++++++++-- pkg/streamx/streamx_gen_service_test.go | 28 ++--- pkg/streamx/streamx_user_test.go | 119 +++++++++++++----- 14 files changed, 261 insertions(+), 135 deletions(-) diff --git a/client/client_streamx.go b/client/client_streamx.go index 6a0fcb769c..73ca7264d3 100644 --- a/client/client_streamx.go +++ b/client/client_streamx.go @@ -25,7 +25,7 @@ import ( ) type StreamX interface { - NewStream(ctx context.Context, method string, req any, callOptions ...streamxcallopt.CallOption) (streamx.ClientStream, error) + NewStream(ctx context.Context, method string, req any, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) Middlewares() (streamMW streamx.StreamMiddleware, recvMW streamx.StreamRecvMiddleware, sendMW streamx.StreamSendMiddleware) } @@ -34,7 +34,7 @@ func (kc *kClient) Middlewares() (streamMW streamx.StreamMiddleware, recvMW stre } // NewStream create stream for streamx mode -func (kc *kClient) NewStream(ctx context.Context, method string, req any, callOptions ...streamxcallopt.CallOption) (streamx.ClientStream, error) { +func (kc *kClient) NewStream(ctx context.Context, method string, req any, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) { if !kc.inited { panic("client not initialized") } @@ -49,19 +49,25 @@ func (kc *kClient) NewStream(ctx context.Context, method string, req any, callOp err := rpcinfo.AsMutableRPCConfig(ri.Config()).SetInteractionMode(rpcinfo.Streaming) if err != nil { - return nil, err + return nil, nil, err } ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) + + // tracing ctx = kc.opt.TracerCtl.DoStart(ctx, ri) + ctx, copts := streamxcallopt.NewCtxWithCallOptions(ctx) + callOptions = append(callOptions, streamxcallopt.WithStreamCloseCallback(func(ctx context.Context) { + kc.opt.TracerCtl.DoFinish(ctx, ri, err) + })) + copts.Apply(callOptions) streamArgs := streamx.NewStreamArgs(nil) // put streamArgs into response arg // it's an ugly trick but if we don't want to refactor too much, // this is the only way to compatible with current endpoint design err = kc.sEps(ctx, req, streamArgs) - kc.opt.TracerCtl.DoFinish(ctx, ri, err) if err != nil { - return nil, err + return nil, nil, err } - return streamArgs.Stream().(streamx.ClientStream), nil + return ctx, streamArgs.Stream().(streamx.ClientStream), nil } diff --git a/client/streamxclient/client_gen.go b/client/streamxclient/client_gen.go index 001d03c149..6c1d4af974 100644 --- a/client/streamxclient/client_gen.go +++ b/client/streamxclient/client_gen.go @@ -28,7 +28,7 @@ import ( func InvokeStream[Req, Res any]( ctx context.Context, cli client.StreamX, smode serviceinfo.StreamingMode, method string, req *Req, res *Res, callOptions ...streamxcallopt.CallOption, -) (stream *streamx.GenericClientStream[Req, Res], err error) { +) (context.Context, *streamx.GenericClientStream[Req, Res], error) { reqArgs, resArgs := streamx.NewStreamReqArgs(nil), streamx.NewStreamResArgs(nil) streamArgs := streamx.NewStreamArgs(nil) // important notes: please don't set a typed nil value into interface arg like NewStreamReqArgs({typ: *Res, ptr: nil}) @@ -40,11 +40,11 @@ func InvokeStream[Req, Res any]( resArgs.SetRes(res) } - cs, err := cli.NewStream(ctx, method, req, callOptions...) + ctx, cs, err := cli.NewStream(ctx, method, req, callOptions...) if err != nil { - return nil, err + return nil, nil, err } - stream = streamx.NewGenericClientStream[Req, Res](cs) + stream := streamx.NewGenericClientStream[Req, Res](cs) streamx.AsMutableStreamArgs(streamArgs).SetStream(stream) streamMW, recvMW, sendMW := cli.Middlewares() @@ -85,7 +85,7 @@ func InvokeStream[Req, Res any]( err = streamInvoke(ctx, streamArgs, reqArgs, resArgs) } if err != nil { - return nil, err + return nil, nil, err } - return stream, nil + return ctx, stream, nil } diff --git a/client/streamxclient/streamxcallopt/call_option.go b/client/streamxclient/streamxcallopt/call_option.go index de48a90ac9..5e2325ab0f 100644 --- a/client/streamxclient/streamxcallopt/call_option.go +++ b/client/streamxclient/streamxcallopt/call_option.go @@ -17,24 +17,48 @@ package streamxcallopt import ( - "fmt" - "strings" - "time" + "context" ) +type StreamCloseCallback func(ctx context.Context) + type CallOptions struct { - RPCTimeout time.Duration + StreamCloseCallback StreamCloseCallback } type CallOption struct { - f func(o *CallOptions, di *strings.Builder) + f func(o *CallOptions) } type WithCallOption func(o *CallOption) -func WithRPCTimeout(rpcTimeout time.Duration) CallOption { - return CallOption{f: func(o *CallOptions, di *strings.Builder) { - di.WriteString(fmt.Sprintf("WithRPCTimeout(%d)", rpcTimeout)) - o.RPCTimeout = rpcTimeout +type ctxKeyCallOptions struct{} + +func NewCtxWithCallOptions(ctx context.Context) (context.Context, *CallOptions) { + copts := new(CallOptions) + return context.WithValue(ctx, ctxKeyCallOptions{}, copts), copts +} + +func GetCallOptionsFromCtx(ctx context.Context) *CallOptions { + v := ctx.Value(ctxKeyCallOptions{}) + if v == nil { + return nil + } + copts, ok := v.(*CallOptions) + if !ok { + return nil + } + return copts +} + +func (copts *CallOptions) Apply(opts []CallOption) { + for _, opt := range opts { + opt.f(copts) + } +} + +func WithStreamCloseCallback(callback StreamCloseCallback) CallOption { + return CallOption{f: func(o *CallOptions) { + o.StreamCloseCallback = callback }} } diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index ede3ec40bb..99a0ae069c 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -22,6 +22,7 @@ import ( "io" "net" "runtime/debug" + "sync" "time" "github.com/cloudwego/kitex/internal/wpool" @@ -92,7 +93,18 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context. return ctx, nil } -func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { +// OnRead control the connection level lifecycle. +// only when OnRead return, netpoll can close the connection buffer +func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) { + var wg sync.WaitGroup + defer func() { + wg.Wait() + klog.CtxErrorf(ctx, "KITEX: stream OnRead return: err=%v", err) + _, nerr := t.provider.OnInactive(ctx, conn) + if err == nil && nerr != nil { + err = nerr + } + }() // connection level goroutine for { nctx, ss, nerr := t.provider.OnStream(ctx, conn) @@ -103,8 +115,10 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { klog.CtxErrorf(ctx, "KITEX: OnStream failed: err=%v", nerr) return nerr } + wg.Add(1) // stream level goroutine streamWorkerPool.GoCtx(ctx, func() { + defer wg.Done() err := t.OnStream(nctx, conn, ss) if err != nil && !errors.Is(err, io.EOF) { klog.CtxErrorf(ctx, "KITEX: stream ReadStream failed: err=%v", err) @@ -178,7 +192,6 @@ func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Me } func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) { - _, _ = t.provider.OnInactive(ctx, conn) } func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) { diff --git a/pkg/streamx/client_provider.go b/pkg/streamx/client_provider.go index 7706b38887..1ed92c5ae6 100644 --- a/pkg/streamx/client_provider.go +++ b/pkg/streamx/client_provider.go @@ -19,7 +19,6 @@ package streamx import ( "context" - "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" "github.com/cloudwego/kitex/pkg/rpcinfo" ) @@ -42,5 +41,5 @@ res := stream.Recv(...) type ClientProvider interface { // NewStream create a stream based on rpcinfo and callOptions - NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callOptions ...streamxcallopt.CallOption) (ClientStream, error) + NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (ClientStream, error) } diff --git a/pkg/streamx/client_provider_internal.go b/pkg/streamx/client_provider_internal.go index d28b83d55b..ba3f4da6f0 100644 --- a/pkg/streamx/client_provider_internal.go +++ b/pkg/streamx/client_provider_internal.go @@ -19,7 +19,6 @@ package streamx import ( "context" - "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" "github.com/cloudwego/kitex/pkg/rpcinfo" ) @@ -31,8 +30,8 @@ type internalClientProvider struct { ClientProvider } -func (p internalClientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callOptions ...streamxcallopt.CallOption) (ClientStream, error) { - cs, err := p.ClientProvider.NewStream(ctx, ri, callOptions...) +func (p internalClientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (ClientStream, error) { + cs, err := p.ClientProvider.NewStream(ctx, ri) if err != nil { return nil, err } diff --git a/pkg/streamx/provider/jsonrpc/client_provier.go b/pkg/streamx/provider/jsonrpc/client_provier.go index 304187a1be..987206920d 100644 --- a/pkg/streamx/provider/jsonrpc/client_provier.go +++ b/pkg/streamx/provider/jsonrpc/client_provier.go @@ -20,7 +20,6 @@ import ( "context" "net" - "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -43,7 +42,7 @@ type clientProvider struct { payloadLimit int } -func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callOptions ...streamxcallopt.CallOption) (streamx.ClientStream, error) { +func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (streamx.ClientStream, error) { invocation := ri.Invocation() method := invocation.MethodName() addr := ri.To().Address() diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go index f79e760151..119d6bdacd 100644 --- a/pkg/streamx/provider/ttstream/client_provier.go +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -19,12 +19,10 @@ package ttstream import ( "context" "runtime" - "sync/atomic" "github.com/bytedance/gopkg/cloud/metainfo" "github.com/cloudwego/gopkg/protocol/ttheader" - "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -51,7 +49,7 @@ type clientProvider struct { headerHandler HeaderFrameHandler } -func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callOptions ...streamxcallopt.CallOption) (streamx.ClientStream, error) { +func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (streamx.ClientStream, error) { rconfig := ri.Config() invocation := ri.Invocation() method := invocation.MethodName() @@ -95,31 +93,16 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callO }) cs := newClientStream(sio.stream) - // the END of a client stream means it should send and recv trailer and not hold by user anymore - var ended uint32 - sio.setEOFCallback(func() { - // if stream is ended by both parties, put the transport back to pool - if atomic.AddUint32(&ended, 1) == 2 { - _ = c.streamFinalize(sio, trans) - } - }) runtime.SetFinalizer(cs, func(cstream *clientStream) { // it's safe to call CloseSend twice - // we do repeated CloseSend here to ensure stream can be closed normally + // we do CloseSend here to ensure stream can be closed normally _ = cstream.CloseSend(ctx) - // only delete stream when clientStream be finalized - if atomic.AddUint32(&ended, 1) == 2 { - _ = c.streamFinalize(sio, trans) + + sio.close() + trans.streamDelete(sio.stream.sid) + if trans.IsActive() { + c.transPool.Put(trans) } }) return cs, err } - -func (c clientProvider) streamFinalize(sio *streamIO, trans *transport) error { - sio.close() - err := trans.streamDelete(sio.stream.sid) - if trans.IsActive() { - c.transPool.Put(trans) - } - return err -} diff --git a/pkg/streamx/provider/ttstream/server_provider.go b/pkg/streamx/provider/ttstream/server_provider.go index 5b62dea78f..6313d8884f 100644 --- a/pkg/streamx/provider/ttstream/server_provider.go +++ b/pkg/streamx/provider/ttstream/server_provider.go @@ -74,6 +74,11 @@ func (s serverProvider) OnActive(ctx context.Context, conn net.Conn) (context.Co } func (s serverProvider) OnInactive(ctx context.Context, conn net.Conn) (context.Context, error) { + trans, _ := ctx.Value(serverTransCtxKey{}).(*transport) + if trans == nil { + return ctx, nil + } + trans.WaitClosed() return ctx, nil } diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index 3bdb35dc25..0dd2e02759 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -279,7 +279,7 @@ func (s *serverStream) close(ex tException) error { if err != nil { return err } - err = s.trans.streamDelete(s.sid) + s.trans.streamDelete(s.sid) s.stream.close() - return err + return nil } diff --git a/pkg/streamx/provider/ttstream/stream_io.go b/pkg/streamx/provider/ttstream/stream_io.go index 8eabdd5076..4e4b21f359 100644 --- a/pkg/streamx/provider/ttstream/stream_io.go +++ b/pkg/streamx/provider/ttstream/stream_io.go @@ -32,17 +32,15 @@ type streamIOMsg struct { } type streamIO struct { - ctx context.Context - trigger chan struct{} - stream *stream - pipe *container.Pipe[streamIOMsg] - cache [1]streamIOMsg - exception error // once has exception, the stream should not work normally again - // eofFlag == 2 when both parties send trailers - eofFlag int32 - // eofCallback will be called when eofFlag == 2 - // eofCallback will not be called if stream is not be ended in a normal way - eofCallback func() + ctx context.Context + trigger chan struct{} + stream *stream + pipe *container.Pipe[streamIOMsg] + cache [1]streamIOMsg + exception error // once has exception, the stream should not work normally again + eofFlag int32 + callbackFlag int32 + closeCallback func(ctx context.Context) } func newStreamIO(ctx context.Context, s *stream) *streamIO { @@ -54,10 +52,6 @@ func newStreamIO(ctx context.Context, s *stream) *streamIO { return sio } -func (s *streamIO) setEOFCallback(f func()) { - s.eofCallback = f -} - func (s *streamIO) input(ctx context.Context, msg streamIOMsg) { err := s.pipe.Write(ctx, msg) if err != nil { @@ -90,25 +84,33 @@ func (s *streamIO) output(ctx context.Context) (msg streamIOMsg, err error) { return msg, nil } +func (s *streamIO) runCloseCallback() { + if s.closeCallback != nil && atomic.CompareAndSwapInt32(&s.callbackFlag, 0, 1) { + s.closeCallback(s.ctx) + } +} + func (s *streamIO) closeRecv() { s.pipe.Close() - if atomic.AddInt32(&s.eofFlag, 1) == 2 && s.eofCallback != nil { - s.eofCallback() + if s.closeCallback != nil && atomic.AddInt32(&s.eofFlag, 1) == 2 { + s.runCloseCallback() } } func (s *streamIO) closeSend() { - if atomic.AddInt32(&s.eofFlag, 1) == 2 && s.eofCallback != nil { - s.eofCallback() + if s.closeCallback != nil && atomic.AddInt32(&s.eofFlag, 1) == 2 { + s.runCloseCallback() } } func (s *streamIO) close() { - s.stream.close() s.pipe.Close() + s.stream.close() + s.runCloseCallback() } func (s *streamIO) cancel() { s.pipe.Cancel() s.stream.close() + s.runCloseCallback() } diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index dc97d62f09..97a58c7797 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -30,6 +30,10 @@ import ( "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/rpcinfo" + + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" @@ -58,6 +62,7 @@ type transport struct { fpipe *container.Pipe[*Frame] // out-coming frame pipe closedFlag int32 streamingFlag int32 // flag == 0 means there is no active stream on transport + closedTrigger chan struct{} } func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Connection) *transport { @@ -65,15 +70,19 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne // TODO: let it configurable _ = conn.SetReadTimeout(time.Minute * 10) t := &transport{ - kind: kind, - sinfo: sinfo, - conn: conn, - streams: sync.Map{}, - spipe: container.NewPipe[*stream](), - scache: make([]*stream, 0, streamCacheSize), - fpipe: container.NewPipe[*Frame](), + kind: kind, + sinfo: sinfo, + conn: conn, + streams: sync.Map{}, + spipe: container.NewPipe[*stream](), + scache: make([]*stream, 0, streamCacheSize), + fpipe: container.NewPipe[*Frame](), + closedTrigger: make(chan struct{}, 2), } go func() { + defer func() { + t.closedTrigger <- struct{}{} + }() err := t.loopRead() if err != nil { if !isIgnoreError(err) { @@ -85,6 +94,9 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne } }() go func() { + defer func() { + t.closedTrigger <- struct{}{} + }() err := t.loopWrite() if err != nil { if !isIgnoreError(err) { @@ -103,18 +115,28 @@ func (t *transport) Close() (err error) { if !atomic.CompareAndSwapInt32(&t.closedFlag, 0, 1) { return nil } - klog.Debugf("transport[%s] is closing", t.conn.LocalAddr()) + switch t.kind { + case clientTransport: + klog.Debugf("transport[%d-%s] is closing", t.kind, t.conn.LocalAddr()) + case serverTransport: + klog.Debugf("transport[%d-%s] is closing", t.kind, t.conn.RemoteAddr()) + } t.spipe.Close() t.fpipe.Close() t.streams.Range(func(key, value any) bool { sio := value.(*streamIO) sio.close() - _ = t.streamDelete(sio.stream.sid) + t.streamDelete(sio.stream.sid) return true }) return err } +func (t *transport) WaitClosed() { + <-t.closedTrigger + <-t.closedTrigger +} + func (t *transport) IsActive() bool { return atomic.LoadInt32(&t.closedFlag) == 0 && t.conn.IsActive() } @@ -122,6 +144,11 @@ func (t *transport) IsActive() bool { func (t *transport) storeStreamIO(ctx context.Context, s *stream) *streamIO { sio := newStreamIO(ctx, s) t.streams.Store(s.sid, sio) + + copts := streamxcallopt.GetCallOptionsFromCtx(ctx) + if copts != nil && copts.StreamCloseCallback != nil { + sio.closeCallback = copts.StreamCloseCallback + } return sio } @@ -222,7 +249,7 @@ func (t *transport) loopWrite() error { } for i := 0; i < n; i++ { fr := fcache[i] - klog.Debugf("transport[%d] EncodeFrame: fr=%v", t.kind, fr) + klog.Debugf("transport[%d] EncodeFrame: fr=%v IsActive=%v", t.kind, fr, t.conn.IsActive()) if err = EncodeFrame(context.Background(), writer, fr); err != nil { return err } @@ -251,6 +278,13 @@ func (t *transport) streamSend(ctx context.Context, sid int32, method string, wh if err != nil { return err } + // tracing + ri := rpcinfo.GetRPCInfo(ctx) + if ri != nil && ri.Stats() != nil { + if rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()); rpcStats != nil { + rpcStats.IncrSendSize(uint64(len(payload))) + } + } return t.writeFrame( streamFrame{sid: sid, method: method}, nil, dataFrameType, payload, @@ -298,6 +332,14 @@ func (t *transport) streamRecv(ctx context.Context, sid int32, data any) (err er err = DecodePayload(context.Background(), msg.payload, data.(thrift.FastCodec)) // payload will not be access after decode mcache.Free(msg.payload) + + // tracing + ri := rpcinfo.GetRPCInfo(ctx) + if ri != nil && ri.Stats() != nil { + if rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()); rpcStats != nil { + rpcStats.IncrRecvSize(uint64(len(msg.payload))) + } + } return err } @@ -313,14 +355,13 @@ func (t *transport) streamCloseRecv(s *stream, exception error) error { return nil } -func (t *transport) streamDelete(sid int32) (err error) { +func (t *transport) streamDelete(sid int32) { // remove stream from transport _, ok := t.streams.LoadAndDelete(sid) if !ok { - return nil + return } atomic.AddInt32(&t.streamingFlag, -1) - return nil } func (t *transport) IsStreaming() bool { diff --git a/pkg/streamx/streamx_gen_service_test.go b/pkg/streamx/streamx_gen_service_test.go index ea774a2cec..5c049ec79f 100644 --- a/pkg/streamx/streamx_gen_service_test.go +++ b/pkg/streamx/streamx_gen_service_test.go @@ -202,18 +202,18 @@ type PingPongClientInterface interface { type StreamingClientInterface interface { Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream streamx.ClientStreamingClient[Request, Response], err error) + context.Context, streamx.ClientStreamingClient[Request, Response], error) ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( - stream streamx.ServerStreamingClient[Response], err error) + context.Context, streamx.ServerStreamingClient[Response], error) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream streamx.BidiStreamingClient[Request, Response], err error) + context.Context, streamx.BidiStreamingClient[Request, Response], error) UnaryWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error) ClientStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream streamx.ClientStreamingClient[Request, Response], err error) + context.Context, streamx.ClientStreamingClient[Request, Response], error) ServerStreamWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( - stream streamx.ServerStreamingClient[Response], err error) + context.Context, streamx.ServerStreamingClient[Response], error) BidiStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream streamx.BidiStreamingClient[Request, Response], err error) + context.Context, streamx.BidiStreamingClient[Request, Response], error) } // --- Define Client Implementation --- @@ -239,7 +239,7 @@ func (c *kClient) PingPong(ctx context.Context, req *Request) (r *Response, err func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (*Response, error) { res := new(Response) - _, err := streamxclient.InvokeStream[Request, Response]( + _, _, err := streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingUnary, "Unary", req, res, callOptions...) if err != nil { return nil, err @@ -248,21 +248,21 @@ func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...stream } func (c *kClient) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream streamx.ClientStreamingClient[Request, Response], err error, + context.Context, streamx.ClientStreamingClient[Request, Response], error, ) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingClient, "ClientStream", nil, nil, callOptions...) } func (c *kClient) ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( - stream streamx.ServerStreamingClient[Response], err error, + context.Context, streamx.ServerStreamingClient[Response], error, ) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingServer, "ServerStream", req, nil, callOptions...) } func (c *kClient) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream streamx.BidiStreamingClient[Request, Response], err error, + context.Context, streamx.BidiStreamingClient[Request, Response], error, ) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingBidirectional, "BidiStream", nil, nil, callOptions...) @@ -270,7 +270,7 @@ func (c *kClient) BidiStream(ctx context.Context, callOptions ...streamxcallopt. func (c *kClient) UnaryWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (*Response, error) { res := new(Response) - _, err := streamxclient.InvokeStream[Request, Response]( + _, _, err := streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingUnary, "UnaryWithErr", req, res, callOptions...) if err != nil { return nil, err @@ -279,21 +279,21 @@ func (c *kClient) UnaryWithErr(ctx context.Context, req *Request, callOptions .. } func (c *kClient) ClientStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream streamx.ClientStreamingClient[Request, Response], err error, + context.Context, streamx.ClientStreamingClient[Request, Response], error, ) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingClient, "ClientStreamWithErr", nil, nil, callOptions...) } func (c *kClient) ServerStreamWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( - stream streamx.ServerStreamingClient[Response], err error, + context.Context, streamx.ServerStreamingClient[Response], error, ) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingServer, "ServerStreamWithErr", req, nil, callOptions...) } func (c *kClient) BidiStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( - stream streamx.BidiStreamingClient[Request, Response], err error, + context.Context, streamx.BidiStreamingClient[Request, Response], error, ) { return streamxclient.InvokeStream[Request, Response]( ctx, c.streamer, serviceinfo.StreamingBidirectional, "BidiStreamWithErr", nil, nil, callOptions...) diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index 8d5a1b92ef..bccac99536 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -19,6 +19,7 @@ package streamx_test import ( "context" "errors" + "fmt" "io" "log" "net/http" @@ -34,8 +35,8 @@ import ( "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/streamxclient" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote/codec/thrift" + "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" "github.com/cloudwego/kitex/server" @@ -52,8 +53,6 @@ type testCase struct { } func init() { - klog.SetLevel(klog.LevelWarn) - sp, _ := ttstream.NewServerProvider(streamingServiceInfo) cp, _ := ttstream.NewClientProvider(streamingServiceInfo, ttstream.WithClientLongConnPool(ttstream.LongConnConfig{MaxIdleTimeout: time.Millisecond * 100})) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_LongConn", ClientProvider: cp, ServerProvider: sp}) @@ -67,6 +66,7 @@ func TestMain(m *testing.M) { go func() { log.Println(http.ListenAndServe("localhost:6060", nil)) }() + // klog.SetLevel(klog.LevelDebug) m.Run() } @@ -221,13 +221,12 @@ func TestStreamingBasic(t *testing.T) { test.Assert(t, err == nil, err) // prepare metainfo - ctx := context.Background() - ctx = setMetadata(ctx) + octx := setMetadata(context.Background()) t.Logf("=== PingPong ===") req := new(Request) req.Message = "PingPong" - res, err := pingpongClient.PingPong(ctx, req) + res, err := pingpongClient.PingPong(octx, req) test.Assert(t, err == nil, err) test.Assert(t, req.Message == res.Message, res) @@ -235,7 +234,7 @@ func TestStreamingBasic(t *testing.T) { req = new(Request) req.Type = 10000 req.Message = "Unary" - res, err = streamClient.Unary(ctx, req) + res, err = streamClient.Unary(octx, req) test.Assert(t, err == nil, err) test.Assert(t, req.Type == res.Type, res.Type) test.Assert(t, req.Message == res.Message, res.Message) @@ -249,7 +248,7 @@ func TestStreamingBasic(t *testing.T) { // client stream round := 5 t.Logf("=== ClientStream ===") - cs, err := streamClient.ClientStream(ctx) + ctx, cs, err := streamClient.ClientStream(octx) test.Assert(t, err == nil, err) for i := 0; i < round; i++ { req := new(Request) @@ -275,7 +274,7 @@ func TestStreamingBasic(t *testing.T) { t.Logf("=== ServerStream ===") req = new(Request) req.Message = "ServerStream" - ss, err := streamClient.ServerStream(ctx, req) + ctx, ss, err := streamClient.ServerStream(octx, req) test.Assert(t, err == nil, err) received := 0 for { @@ -306,7 +305,7 @@ func TestStreamingBasic(t *testing.T) { for c := 0; c < concurrent; c++ { atomic.AddInt32(&serverStreamCount, -1) go func() { - bs, err := streamClient.BidiStream(ctx) + ctx, bs, err := streamClient.BidiStream(octx) test.Assert(t, err == nil, err) msg := "BidiStream" var wg sync.WaitGroup @@ -363,7 +362,7 @@ func TestStreamingBasic(t *testing.T) { assertBizErr(t, err) t.Logf("=== ClientStreamWithErr normalErr ===") - cliStream, err := streamClient.ClientStreamWithErr(ctx) + ctx, cliStream, err := streamClient.ClientStreamWithErr(octx) test.Assert(t, err == nil, err) test.Assert(t, cliStream != nil, cliStream) req = new(Request) @@ -376,7 +375,7 @@ func TestStreamingBasic(t *testing.T) { assertNormalErr(t, err) t.Logf("=== ClientStreamWithErr bizErr ===") - cliStream, err = streamClient.ClientStreamWithErr(ctx) + ctx, cliStream, err = streamClient.ClientStreamWithErr(octx) test.Assert(t, err == nil, err) test.Assert(t, cliStream != nil, cliStream) req = new(Request) @@ -391,7 +390,7 @@ func TestStreamingBasic(t *testing.T) { t.Logf("=== ServerStreamWithErr normalErr ===") req = new(Request) req.Type = normalErr - svrStream, err := streamClient.ServerStreamWithErr(ctx, req) + ctx, svrStream, err := streamClient.ServerStreamWithErr(octx, req) test.Assert(t, err == nil, err) test.Assert(t, svrStream != nil, svrStream) res, err = svrStream.Recv(ctx) @@ -402,7 +401,7 @@ func TestStreamingBasic(t *testing.T) { t.Logf("=== ServerStreamWithErr bizErr ===") req = new(Request) req.Type = bizErr - svrStream, err = streamClient.ServerStreamWithErr(ctx, req) + ctx, svrStream, err = streamClient.ServerStreamWithErr(octx, req) test.Assert(t, err == nil, err) test.Assert(t, svrStream != nil, svrStream) res, err = svrStream.Recv(ctx) @@ -411,7 +410,7 @@ func TestStreamingBasic(t *testing.T) { assertBizErr(t, err) t.Logf("=== BidiStreamWithErr normalErr ===") - bidiStream, err := streamClient.BidiStreamWithErr(ctx) + ctx, bidiStream, err := streamClient.BidiStreamWithErr(octx) test.Assert(t, err == nil, err) test.Assert(t, bidiStream != nil, bidiStream) req = new(Request) @@ -424,7 +423,7 @@ func TestStreamingBasic(t *testing.T) { assertNormalErr(t, err) t.Logf("=== BidiStreamWithErr bizErr ===") - bidiStream, err = streamClient.BidiStreamWithErr(ctx) + ctx, bidiStream, err = streamClient.BidiStreamWithErr(octx) test.Assert(t, err == nil, err) test.Assert(t, bidiStream != nil, bidiStream) req = new(Request) @@ -437,7 +436,7 @@ func TestStreamingBasic(t *testing.T) { assertBizErr(t, err) t.Logf("=== Timeout by Ctx ===") - bs, err := streamClient.BidiStream(ctx) + ctx, bs, err := streamClient.BidiStream(octx) test.Assert(t, err == nil, err) req = new(Request) req.Message = string(make([]byte, 1024)) @@ -459,7 +458,7 @@ func TestStreamingBasic(t *testing.T) { streamxclient.WithProvider(tc.ClientProvider), streamxclient.WithRecvTimeout(time.Nanosecond), ) - bs, err = streamClient.BidiStream(ctx) + ctx, bs, err = streamClient.BidiStream(octx) test.Assert(t, err == nil, err) req = new(Request) req.Message = string(make([]byte, 1024)) @@ -517,14 +516,14 @@ func TestStreamingGoroutineLeak(t *testing.T) { streamxclient.WithHostPorts(addr), streamxclient.WithProvider(tc.ClientProvider), ) - ctx := context.Background() + octx := context.Background() msg := "BidiStream" t.Logf("=== Checking only one connection be reused ===") var wg sync.WaitGroup for i := 0; i < 12; i++ { wg.Add(1) - bs, err := streamClient.BidiStream(ctx) + ctx, bs, err := streamClient.BidiStream(octx) test.Assert(t, err == nil, err) req := new(Request) req.Message = string(make([]byte, 1024)) @@ -548,8 +547,7 @@ func TestStreamingGoroutineLeak(t *testing.T) { streamList := make([]streamx.ServerStream, streams) atomic.StoreInt32(&streamStarted, 0) for i := 0; i < streams; i++ { - ctx := context.Background() - bs, err := streamClient.BidiStream(ctx) + _, bs, err := streamClient.BidiStream(octx) test.Assert(t, err == nil, err) streamList[i] = bs } @@ -572,25 +570,82 @@ func TestStreamingGoroutineLeak(t *testing.T) { for i := 0; i < streams; i++ { wg.Add(1) go func() { - bs, err := streamClient.BidiStream(ctx) + defer wg.Done() + + ctx, bs, err := streamClient.BidiStream(octx) test.Assert(t, err == nil, err) req := new(Request) req.Message = msg err = bs.Send(ctx, req) test.Assert(t, err == nil, err) - go func() { - defer wg.Done() - res, err := bs.Recv(ctx) - test.Assert(t, err == nil, err) - err = bs.CloseSend(ctx) - test.Assert(t, err == nil, err) - test.Assert(t, res.Message == msg, res.Message) - testHeaderAndTrailer(t, bs) - }() + res, err := bs.Recv(ctx) + test.Assert(t, err == nil, err) + err = bs.CloseSend(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, res.Message == msg, res.Message) + + testHeaderAndTrailer(t, bs) }() } wg.Wait() }) } } + +func TestStreamingException(t *testing.T) { + for _, tc := range providerTestCases { + t.Run(tc.Name, func(t *testing.T) { + addr := test.GetLocalAddress() + ln, _ := netpoll.CreateListener("tcp", addr) + defer ln.Close() + + // create server + svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) + _ = svr.RegisterService( + streamingServiceInfo, new(streamingService), + streamxserver.WithProvider(tc.ServerProvider), + ) + go func() { + _ = svr.Run() + }() + defer svr.Stop() + test.WaitServerStart(addr) + + var circuitBreaker int32 + circuitBreakerErr := fmt.Errorf("circuitBreaker on") + streamClient, _ := NewStreamingClient( + "kitex.service.streaming", + streamxclient.WithHostPorts(addr), + streamxclient.WithProvider(tc.ClientProvider), + streamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + ri := rpcinfo.GetRPCInfo(ctx) + test.Assert(t, ri.To().Address() != nil) + if atomic.LoadInt32(&circuitBreaker) > 0 { + return circuitBreakerErr + } + return next(ctx, streamArgs, reqArgs, resArgs) + } + }), + ) + octx := context.Background() + + // assert circuitBreaker error + atomic.StoreInt32(&circuitBreaker, 1) + ctx, bs, err := streamClient.BidiStream(octx) + test.Assert(t, errors.Is(err, circuitBreakerErr), err) + atomic.StoreInt32(&circuitBreaker, 0) + + // assert context deadline error + ctx, cancel := context.WithTimeout(octx, time.Millisecond) + ctx, bs, err = streamClient.BidiStream(ctx) + test.Assert(t, err == nil, err) + res, err := bs.Recv(ctx) + cancel() + test.Assert(t, res == nil && err != nil, res, err) + test.Assert(t, errors.Is(err, ctx.Err()), err) + test.Assert(t, errors.Is(err, context.DeadlineExceeded), err) + }) + } +} From 9cd69061990627ef3605315de82be0e782556ea6 Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Tue, 29 Oct 2024 11:51:42 +0800 Subject: [PATCH 11/34] feat: support TTHeader Streaming detailed error handling (#1594) --- .../provider/ttstream/client_provier.go | 13 +- .../ttstream/client_trans_pool_mux.go | 7 +- .../ttstream/client_trans_pool_shortconn.go | 4 +- .../ttstream/container/object_pool.go | 4 +- .../provider/ttstream/error_scenario_test.go | 246 ++++++++++++++++++ .../provider/ttstream/errors/errors.go | 49 ++++ pkg/streamx/provider/ttstream/frame.go | 15 +- .../provider/ttstream/server_provider.go | 2 +- pkg/streamx/provider/ttstream/stream.go | 49 +++- pkg/streamx/provider/ttstream/stream_io.go | 23 +- pkg/streamx/provider/ttstream/transport.go | 124 ++++----- pkg/streamx/streamx_common_test.go | 9 +- 12 files changed, 432 insertions(+), 113 deletions(-) create mode 100644 pkg/streamx/provider/ttstream/error_scenario_test.go create mode 100644 pkg/streamx/provider/ttstream/errors/errors.go diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go index 119d6bdacd..ba78a1931f 100644 --- a/pkg/streamx/provider/ttstream/client_provier.go +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -78,28 +78,27 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (stre return nil, err } - sio, err := trans.newStreamIO(ctx, method, intHeader, strHeader) + s, err := trans.newStream(ctx, method, intHeader, strHeader) if err != nil { return nil, err } - sio.stream.setRecvTimeout(rconfig.StreamRecvTimeout()) + s.setRecvTimeout(rconfig.StreamRecvTimeout()) // only client can set meta frame handler - sio.stream.setMetaFrameHandler(c.metaHandler) + s.setMetaFrameHandler(c.metaHandler) // if ctx from server side, we should cancel the stream when server handler already returned // TODO: this canceling transmit should be configurable ktx.RegisterCancelCallback(ctx, func() { - sio.cancel() + s.cancel() }) - cs := newClientStream(sio.stream) + cs := newClientStream(s) runtime.SetFinalizer(cs, func(cstream *clientStream) { // it's safe to call CloseSend twice // we do CloseSend here to ensure stream can be closed normally _ = cstream.CloseSend(ctx) - sio.close() - trans.streamDelete(sio.stream.sid) + s.close(nil) if trans.IsActive() { c.transPool.Put(trans) } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go index d0d9fc8a60..b16acab47e 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go @@ -17,6 +17,7 @@ package ttstream import ( + "errors" "runtime" "sync" "sync/atomic" @@ -26,6 +27,7 @@ import ( "golang.org/x/sync/singleflight" "github.com/cloudwego/kitex/pkg/serviceinfo" + terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" ) var _ transPool = (*muxTransPool)(nil) @@ -60,12 +62,13 @@ func (tl *muxTransList) Get(sinfo *serviceinfo.ServiceInfo, network, addr string trans = newTransport(clientTransport, sinfo, conn) _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { // peer close - _ = trans.Close() + _ = trans.Close(terrors.ErrTransport.WithCause(errors.New("netpoll connection closed"))) return nil }) runtime.SetFinalizer(trans, func(trans *transport) { // self close when not hold by user - _ = trans.Close() + // todo: think about a more ideal error + _ = trans.Close(nil) }) tl.L.Lock() tl.transports[idx] = trans diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go index e5eb081520..a5c0661a0a 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go @@ -17,9 +17,11 @@ package ttstream import ( + "errors" "time" "github.com/cloudwego/kitex/pkg/serviceinfo" + terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" ) func newShortConnTransPool() transPool { @@ -40,5 +42,5 @@ func (c *shortConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr s } func (c *shortConnTransPool) Put(trans *transport) { - _ = trans.Close() + _ = trans.Close(terrors.ErrTransport.WithCause(errors.New("short connection closed"))) } diff --git a/pkg/streamx/provider/ttstream/container/object_pool.go b/pkg/streamx/provider/ttstream/container/object_pool.go index 51658e706e..8450f357fa 100644 --- a/pkg/streamx/provider/ttstream/container/object_pool.go +++ b/pkg/streamx/provider/ttstream/container/object_pool.go @@ -23,7 +23,7 @@ import ( ) type Object interface { - Close() error + Close(exception error) error } type objectItem struct { @@ -106,7 +106,7 @@ func (s *ObjectPool) cleaning() { return false, false } deleted++ - _ = o.object.Close() + _ = o.object.Close(nil) return true, true }) } diff --git a/pkg/streamx/provider/ttstream/error_scenario_test.go b/pkg/streamx/provider/ttstream/error_scenario_test.go new file mode 100644 index 0000000000..1ed00344f9 --- /dev/null +++ b/pkg/streamx/provider/ttstream/error_scenario_test.go @@ -0,0 +1,246 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "context" + "encoding/binary" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/bytedance/gopkg/cloud/metainfo" + "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" + terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" +) + +const ( + testTypeKey = "testType" + + testTypeIllegalFrame = "illegalFrame" + testTypeUnexpectedHeaderFrame = "unexpectedHeaderFrame" + testTypeUnexpectedTrailerFrame = "unexpectedTrailerFrame" + testTypeIllegalBizErr = "illegalBizErr" + testTypeApplicationException = "applicationException" +) + +var streamingServiceInfo = &serviceinfo.ServiceInfo{ + ServiceName: "kitex.service.streaming", + Methods: map[string]serviceinfo.MethodInfo{ + "TriggerStreamErr": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return nil + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + ), + }, + Extra: map[string]interface{}{"streamingFlag": true, "streamx": true}, +} + +type illegalFrameType int32 + +const ( + MissStreamingFlag illegalFrameType = iota +) + +func encodeIllegalFrame(t *testing.T, ctx context.Context, writer bufiox.Writer, fr *Frame, flag illegalFrameType) { + var param ttheader.EncodeParam + written := writer.WrittenLen() + switch flag { + case MissStreamingFlag: + param = ttheader.EncodeParam{ + SeqID: fr.sid, + ProtocolID: ttheader.ProtocolIDThriftStruct, + } + param.IntInfo = fr.meta + if param.IntInfo == nil { + param.IntInfo = make(IntHeader) + } + param.IntInfo[ttheader.FrameType] = frameTypeToString[fr.typ] + param.IntInfo[ttheader.ToMethod] = fr.method + totalLenField, err := ttheader.Encode(ctx, param, writer) + if err != nil { + t.Errorf("ttheader Encode failed, err: %v", err) + } + written = writer.WrittenLen() - written + binary.BigEndian.PutUint32(totalLenField, uint32(written-4)) + } +} + +func TestErrorScenario(t *testing.T) { + klog.SetLevel(klog.LevelDebug) + addr := test.GetLocalAddress() + nAddr, err := net.ResolveTCPAddr("tcp", addr) + test.Assert(t, err == nil, err) + ln, err := netpoll.CreateListener("tcp", addr) + test.Assert(t, err == nil, err) + defer ln.Close() + sp, err := NewServerProvider(streamingServiceInfo) + test.Assert(t, err == nil, err) + onConnect := func(ctx context.Context, conn netpoll.Connection) context.Context { + nctx, err := sp.OnActive(ctx, conn) + test.Assert(t, err == nil, err) + nctx, ss, nerr := sp.OnStream(nctx, conn) + test.Assert(t, nerr == nil, nerr) + go func() { + rawss := ss.(*serverStream) + testType, ok := metainfo.GetValue(nctx, testTypeKey) + test.Assert(t, ok) + switch testType { + case testTypeIllegalFrame: + encodeIllegalFrame(t, nctx, newWriterBuffer(rawss.trans.conn.Writer()), &Frame{ + streamFrame: streamFrame{ + sid: rawss.sid, + }, + typ: headerFrameType, + }, MissStreamingFlag) + rawss.trans.conn.Writer().Flush() + case testTypeUnexpectedHeaderFrame: + hd := streamx.Header{ + "key": "val", + } + rawss.trans.streamSendHeader(rawss.stream, hd) + rawss.trans.streamSendHeader(rawss.stream, hd) + case testTypeUnexpectedTrailerFrame: + rawss.trans.streamCloseSend(rawss.stream, nil, nil) + rawss.trans.streamCloseSend(rawss.stream, nil, nil) + case testTypeIllegalBizErr: + err = rawss.writeTrailer( + streamx.Trailer{ + "biz-status": "1", + "biz-message": "message", + "biz-extra": "invalid extra JSON str", + }, + ) + test.Assert(t, err == nil, err) + err = rawss.sendTrailer(nctx, nil) + test.Assert(t, err == nil, err) + case testTypeApplicationException: + exception := thrift.NewApplicationException(remote.InternalError, "testApplicationException") + err = rawss.sendTrailer(nctx, exception) + test.Assert(t, err == nil, err) + } + }() + return nctx + } + loop, err := netpoll.NewEventLoop(nil, + netpoll.WithOnConnect(onConnect), + netpoll.WithReadTimeout(10*time.Second), + ) + test.Assert(t, err == nil, err) + go func() { + if err := loop.Serve(ln); err != nil { + t.Logf("server failed, err: %v", err) + } + }() + test.WaitServerStart(addr) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + if err := loop.Shutdown(ctx); err != nil { + t.Logf("netpoll eventloop shutdown failed, err: %v", err) + } + }() + + cp, err := NewClientProvider(streamingServiceInfo) + test.Assert(t, err == nil, err) + cctx := context.Background() + cfg := rpcinfo.NewRPCConfig() + cfg.(rpcinfo.MutableRPCConfig).SetStreamRecvTimeout(10 * time.Second) + method := "TriggerStreamErr" + ri := rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo(streamingServiceInfo.ServiceName, method, nAddr, nil), + rpcinfo.NewEndpointInfo(streamingServiceInfo.ServiceName, method, nAddr, nil), + rpcinfo.NewInvocation(streamingServiceInfo.ServiceName, method), + cfg, + rpcinfo.NewRPCStats(), + ) + + t.Run("Illegal Frame", func(t *testing.T) { + t.Run("Non-streaming Frame", func(t *testing.T) { + nctx := metainfo.WithValue(cctx, testTypeKey, testTypeIllegalFrame) + cs, err := cp.NewStream(nctx, ri) + test.Assert(t, err == nil, err) + rawcs := cs.(*clientStream) + err = rawcs.RecvMsg(nctx, nil) + test.Assert(t, errors.Is(err, terrors.ErrIllegalFrame), err) + err = rawcs.SendMsg(nctx, nil) + test.Assert(t, errors.Is(err, terrors.ErrIllegalFrame), err) + }) + }) + + t.Run("Illegal Header Frame", func(t *testing.T) { + t.Run("Receive multiple header", func(t *testing.T) { + nctx := metainfo.WithValue(cctx, testTypeKey, testTypeUnexpectedHeaderFrame) + cs, err := cp.NewStream(nctx, ri) + test.Assert(t, err == nil, err) + rawcs := cs.(*clientStream) + err = rawcs.RecvMsg(nctx, nil) + test.Assert(t, errors.Is(err, terrors.ErrUnexpectedHeader), err) + err = rawcs.SendMsg(nctx, nil) + test.Assert(t, errors.Is(err, terrors.ErrUnexpectedHeader), err) + }) + t.Run("Receive multiple trailer", func(t *testing.T) { + nctx := metainfo.WithValue(cctx, testTypeKey, testTypeUnexpectedTrailerFrame) + cs, err := cp.NewStream(nctx, ri) + test.Assert(t, err == nil, err) + rawcs := cs.(*clientStream) + err = rawcs.RecvMsg(nctx, nil) + test.Assert(t, errors.Is(err, io.EOF), err) + // wait for second trailer frame + time.Sleep(50 * time.Millisecond) + err = rawcs.SendMsg(nctx, nil) + test.Assert(t, errors.Is(err, io.EOF), err) + }) + }) + + t.Run("Trailer Frame", func(t *testing.T) { + t.Run("Illegal BizErr", func(t *testing.T) { + nctx := metainfo.WithValue(cctx, testTypeKey, testTypeIllegalBizErr) + cs, err := cp.NewStream(nctx, ri) + test.Assert(t, err == nil, err) + rawcs := cs.(*clientStream) + err = rawcs.RecvMsg(nctx, nil) + test.Assert(t, errors.Is(err, terrors.ErrIllegalBizErr), err) + err = rawcs.SendMsg(nctx, nil) + test.Assert(t, errors.Is(err, terrors.ErrIllegalBizErr), err) + }) + t.Run("Application Exception", func(t *testing.T) { + nctx := metainfo.WithValue(cctx, testTypeKey, testTypeApplicationException) + cs, err := cp.NewStream(nctx, ri) + test.Assert(t, err == nil, err) + rawcs := cs.(*clientStream) + err = rawcs.RecvMsg(nctx, nil) + test.Assert(t, errors.Is(err, terrors.ErrApplicationException), err) + }) + }) +} diff --git a/pkg/streamx/provider/ttstream/errors/errors.go b/pkg/streamx/provider/ttstream/errors/errors.go new file mode 100644 index 0000000000..4eac75f73e --- /dev/null +++ b/pkg/streamx/provider/ttstream/errors/errors.go @@ -0,0 +1,49 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package errors + +import "errors" + +var ( + ErrUnexpectedHeader = &errType{message: "unexpected header frame"} + ErrUnexpectedTrailer = &errType{message: "unexpected trailer frame"} + ErrApplicationException = &errType{message: "application exception"} + ErrIllegalBizErr = &errType{message: "illegal bizErr"} + ErrIllegalFrame = &errType{message: "illegal frame"} + ErrTransport = &errType{message: "transport is closing"} +) + +type errType struct { + message string + basic error + cause error +} + +func (e *errType) WithCause(err error) error { + return &errType{basic: e, cause: err} +} + +func (e *errType) Error() string { + if e.cause == nil { + return e.message + } + return "[" + e.message + "] " + e.cause.Error() +} + +func (e *errType) Is(target error) bool { + return target == e || target == e.basic || errors.Is(e.cause, target) +} diff --git a/pkg/streamx/provider/ttstream/frame.go b/pkg/streamx/provider/ttstream/frame.go index 77b34426a3..24ca51faf9 100644 --- a/pkg/streamx/provider/ttstream/frame.go +++ b/pkg/streamx/provider/ttstream/frame.go @@ -28,6 +28,7 @@ import ( "github.com/cloudwego/gopkg/protocol/ttheader" "github.com/cloudwego/kitex/pkg/streamx" + terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" ) const ( @@ -100,7 +101,7 @@ func EncodeFrame(ctx context.Context, writer bufiox.Writer, fr *Frame) (err erro totalLenField, err := ttheader.Encode(ctx, param, writer) if err != nil { - return err + return terrors.ErrIllegalFrame.WithCause(err) } if len(fr.payload) > 0 { if nw, ok := writer.(gopkgthrift.NocopyWriter); ok { @@ -109,7 +110,7 @@ func EncodeFrame(ctx context.Context, writer bufiox.Writer, fr *Frame) (err erro _, err = writer.WriteBinary(fr.payload) } if err != nil { - return err + return terrors.ErrTransport.WithCause(err) } } written = writer.WrittenLen() - written @@ -121,11 +122,10 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro var dp ttheader.DecodeParam dp, err = ttheader.Decode(ctx, reader) if err != nil { - return + return nil, terrors.ErrIllegalFrame.WithCause(err) } if dp.Flags&ttheader.HeaderFlagsStreaming == 0 { - err = fmt.Errorf("unexpected header flags: %d", dp.Flags) - return + return nil, terrors.ErrIllegalFrame.WithCause(fmt.Errorf("unexpected header flags: %d", dp.Flags)) } var ftype int32 @@ -143,8 +143,7 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro ftype = trailerFrameType ftrailer = dp.StrInfo default: - err = fmt.Errorf("unexpected frame type: %v", dp.IntInfo[ttheader.FrameType]) - return + return nil, terrors.ErrIllegalFrame.WithCause(fmt.Errorf("unexpected frame type: %v", dp.IntInfo[ttheader.FrameType])) } fmethod := dp.IntInfo[ttheader.ToMethod] fsid := dp.SeqID @@ -156,7 +155,7 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro _, err = reader.ReadBinary(fpayload) // copy read _ = reader.Release(err) if err != nil { - return + return nil, terrors.ErrTransport.WithCause(err) } } else { _ = reader.Release(nil) diff --git a/pkg/streamx/provider/ttstream/server_provider.go b/pkg/streamx/provider/ttstream/server_provider.go index 6313d8884f..9fb2a55f90 100644 --- a/pkg/streamx/provider/ttstream/server_provider.go +++ b/pkg/streamx/provider/ttstream/server_provider.go @@ -68,7 +68,7 @@ func (s serverProvider) OnActive(ctx context.Context, conn net.Conn) (context.Co trans := newTransport(serverTransport, s.sinfo, nconn) _ = nconn.(onDisConnectSetter).SetOnDisconnect(func(ctx context.Context, connection netpoll.Connection) { // server only close transport when peer connection closed - _ = trans.Close() + _ = trans.Close(nil) }) return context.WithValue(ctx, serverTransCtxKey{}, trans), nil } diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index 0dd2e02759..b650e35bc6 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -28,6 +28,7 @@ import ( "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx" + terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" "github.com/cloudwego/kitex/pkg/transmeta" ) @@ -75,6 +76,8 @@ type stream struct { peerEOF int32 headerSig chan int32 trailerSig chan int32 + sio *streamIO + closedFlag int32 // 1 means stream is closed in exception scenario StreamMeta metaHandler MetaFrameHandler @@ -96,7 +99,11 @@ func (s *stream) Method() string { return s.method } -func (s *stream) close() { +// close stream in exception scenario +func (s *stream) close(exception error) { + if !atomic.CompareAndSwapInt32(&s.closedFlag, 0, 1) { + return + } select { case s.headerSig <- streamSigInactive: default: @@ -105,6 +112,20 @@ func (s *stream) close() { case s.trailerSig <- streamSigInactive: default: } + s.sio.close(exception) + s.trans.streamDelete(s.sid) +} + +func (s *stream) isClosed() bool { + return atomic.LoadInt32(&s.closedFlag) == 1 +} + +func (s *stream) isSendFinished() bool { + return atomic.LoadInt32(&s.selfEOF) == 1 +} + +func (s *stream) cancel() { + s.sio.cancel() } func (s *stream) setMetaFrameHandler(h MetaFrameHandler) { @@ -119,11 +140,14 @@ func (s *stream) readMetaFrame(intHeader IntHeader, header streamx.Header, paylo } func (s *stream) readHeader(hd streamx.Header) (err error) { + if s.header != nil { + return terrors.ErrUnexpectedHeader.WithCause(fmt.Errorf("stream[%d] already set header", s.sid)) + } s.header = hd select { case s.headerSig <- streamSigActive: default: - return fmt.Errorf("stream[%d] already set header", s.sid) + return terrors.ErrUnexpectedHeader.WithCause(fmt.Errorf("stream[%d] already set header", s.sid)) } klog.Debugf("stream[%s] read header: %v", s.method, hd) return nil @@ -146,7 +170,7 @@ func (s *stream) sendHeader() (err error) { if wheader == nil { return fmt.Errorf("stream header already sent") } - err = s.trans.streamSendHeader(s.sid, s.method, wheader) + err = s.trans.streamSendHeader(s, wheader) return err } @@ -154,20 +178,22 @@ func (s *stream) sendHeader() (err error) { // readTrailer by server: unblock recv function and return EOF if no unread frame func (s *stream) readTrailerFrame(fr *Frame) (err error) { if !atomic.CompareAndSwapInt32(&s.peerEOF, 0, 1) { - return fmt.Errorf("stream read a unexcept trailer") + return terrors.ErrUnexpectedTrailer.WithCause(fmt.Errorf("content: %v", fr)) } var exception error // when server-side returns non-biz error, it will be wrapped as ApplicationException stored in trailer frame payload if len(fr.payload) > 0 { // exception is type of (*thrift.ApplicationException) - _, _, exception = thrift.UnmarshalFastMsg(fr.payload, nil) + _, _, err = thrift.UnmarshalFastMsg(fr.payload, nil) + exception = terrors.ErrApplicationException.WithCause(err) } else { // when server-side returns biz error, payload is empty and biz error information is stored in trailer frame header bizErr, err := transmeta.ParseBizStatusErr(fr.trailer) if err != nil { - exception = err + exception = terrors.ErrIllegalBizErr.WithCause(err) } else if bizErr != nil { + // bizErr is independent of rpc exception handling exception = bizErr } } @@ -175,7 +201,7 @@ func (s *stream) readTrailerFrame(fr *Frame) (err error) { select { case s.trailerSig <- streamSigActive: default: - return errors.New("already set trailer") + return terrors.ErrUnexpectedTrailer.WithCause(errors.New("already set trailer")) } select { case s.headerSig <- streamSigNone: @@ -207,7 +233,7 @@ func (s *stream) sendTrailer(ctx context.Context, ex tException) (err error) { return fmt.Errorf("stream trailer already sent") } klog.Debugf("transport[%d]-stream[%d] send trailer", s.trans.kind, s.sid) - return s.trans.streamCloseSend(s.sid, s.method, wtrailer, ex) + return s.trans.streamCloseSend(s, wtrailer, ex) } func (s *stream) setRecvTimeout(timeout time.Duration) { @@ -218,7 +244,7 @@ func (s *stream) setRecvTimeout(timeout time.Duration) { } func (s *stream) SendMsg(ctx context.Context, res any) (err error) { - err = s.trans.streamSend(ctx, s.sid, s.method, s.wheader, res) + err = s.trans.streamSend(ctx, s, res) return err } @@ -228,7 +254,7 @@ func (s *stream) RecvMsg(ctx context.Context, req any) error { ctx, cancel = context.WithTimeout(ctx, s.recvTimeout) defer cancel() } - return s.trans.streamRecv(ctx, s.sid, req) + return s.trans.streamRecv(ctx, s, req) } func newClientStream(s *stream) *clientStream { @@ -279,7 +305,6 @@ func (s *serverStream) close(ex tException) error { if err != nil { return err } - s.trans.streamDelete(s.sid) - s.stream.close() + s.stream.close(ex) return nil } diff --git a/pkg/streamx/provider/ttstream/stream_io.go b/pkg/streamx/provider/ttstream/stream_io.go index 4e4b21f359..78b84990d8 100644 --- a/pkg/streamx/provider/ttstream/stream_io.go +++ b/pkg/streamx/provider/ttstream/stream_io.go @@ -33,8 +33,6 @@ type streamIOMsg struct { type streamIO struct { ctx context.Context - trigger chan struct{} - stream *stream pipe *container.Pipe[streamIOMsg] cache [1]streamIOMsg exception error // once has exception, the stream should not work normally again @@ -43,11 +41,9 @@ type streamIO struct { closeCallback func(ctx context.Context) } -func newStreamIO(ctx context.Context, s *stream) *streamIO { +func newStreamIO(ctx context.Context) *streamIO { sio := new(streamIO) sio.ctx = ctx - sio.trigger = make(chan struct{}) - sio.stream = s sio.pipe = container.NewPipe[streamIOMsg]() return sio } @@ -103,14 +99,17 @@ func (s *streamIO) closeSend() { } } -func (s *streamIO) close() { - s.pipe.Close() - s.stream.close() - s.runCloseCallback() -} - func (s *streamIO) cancel() { s.pipe.Cancel() - s.stream.close() s.runCloseCallback() } + +func (s *streamIO) close(exception error) { + if exception != nil { + s.input(context.Background(), streamIOMsg{exception: exception}) + } + s.pipe.Close() + if flag := atomic.AddInt32(&s.eofFlag, 2); (flag == 2 || flag == 3) && s.closeCallback != nil { + s.runCloseCallback() + } +} diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index 97a58c7797..c5e08be6c2 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -53,10 +53,11 @@ func isIgnoreError(err error) bool { } type transport struct { - kind int32 - sinfo *serviceinfo.ServiceInfo - conn netpoll.Connection - streams sync.Map // key=streamID val=streamIO + kind int32 + sinfo *serviceinfo.ServiceInfo + conn netpoll.Connection + // transport should operate directly on stream + streams sync.Map // key=streamID val=stream scache []*stream // size is streamCacheSize spipe *container.Pipe[*stream] // in-coming stream pipe fpipe *container.Pipe[*Frame] // out-coming frame pipe @@ -90,7 +91,7 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne } // if connection is closed by peer, loop read should return ErrConnClosed error, // so we should close transport here - _ = t.Close() + _ = t.Close(err) } }() go func() { @@ -102,7 +103,7 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne if !isIgnoreError(err) { klog.Warnf("transport[%d] loop write err: %v", t.kind, err) } - _ = t.Close() + _ = t.Close(err) } }() return t @@ -111,7 +112,9 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne // Close will close transport and destroy all resource and goroutines // server close transport when connection is disconnected // client close transport when transPool discard the transport -func (t *transport) Close() (err error) { +// when an exception is encountered and the transport needs to be closed, +// the exception is not nil and the currently surviving streams are aware of this exception. +func (t *transport) Close(exception error) (err error) { if !atomic.CompareAndSwapInt32(&t.closedFlag, 0, 1) { return nil } @@ -124,9 +127,8 @@ func (t *transport) Close() (err error) { t.spipe.Close() t.fpipe.Close() t.streams.Range(func(key, value any) bool { - sio := value.(*streamIO) - sio.close() - t.streamDelete(sio.stream.sid) + s := value.(*stream) + s.close(exception) return true }) return err @@ -141,24 +143,22 @@ func (t *transport) IsActive() bool { return atomic.LoadInt32(&t.closedFlag) == 0 && t.conn.IsActive() } -func (t *transport) storeStreamIO(ctx context.Context, s *stream) *streamIO { - sio := newStreamIO(ctx, s) - t.streams.Store(s.sid, sio) - +func (t *transport) storeStream(ctx context.Context, s *stream) { + s.sio = newStreamIO(ctx) copts := streamxcallopt.GetCallOptionsFromCtx(ctx) if copts != nil && copts.StreamCloseCallback != nil { - sio.closeCallback = copts.StreamCloseCallback + s.sio.closeCallback = copts.StreamCloseCallback } - return sio + t.streams.Store(s.sid, s) } -func (t *transport) loadStreamIO(sid int32) (sio *streamIO, ok bool) { +func (t *transport) loadStream(sid int32) (s *stream, ok bool) { val, ok := t.streams.Load(sid) if !ok { - return sio, false + return s, false } - sio = val.(*streamIO) - return sio, true + s = val.(*stream) + return s, true } func (t *transport) readFrame(reader bufiox.Reader) error { @@ -171,9 +171,9 @@ func (t *transport) readFrame(reader bufiox.Reader) error { switch fr.typ { case metaFrameType: - sio, ok := t.loadStreamIO(fr.sid) + s, ok := t.loadStream(fr.sid) if ok { - err = sio.stream.readMetaFrame(fr.meta, fr.header, fr.payload) + err = s.readMetaFrame(fr.meta, fr.header, fr.payload) } else { klog.Errorf("transport[%d] read a unknown stream meta: sid=%d", t.kind, fr.sid) } @@ -183,13 +183,15 @@ func (t *transport) readFrame(reader bufiox.Reader) error { // Header Frame: server recv a new stream smode := t.sinfo.MethodInfo(fr.method).StreamingMode() s := newStream(t, smode, fr.streamFrame) - t.storeStreamIO(context.Background(), s) + t.storeStream(context.Background(), s) err = t.spipe.Write(context.Background(), s) case clientTransport: // Header Frame: client recv header - sio, ok := t.loadStreamIO(fr.sid) + s, ok := t.loadStream(fr.sid) if ok { - err = sio.stream.readHeader(fr.header) + if sErr := s.readHeader(fr.header); sErr != nil { + s.close(sErr) + } } else { klog.Errorf("transport[%d] read a unknown stream header: sid=%d header=%v", t.kind, fr.sid, fr.header) @@ -197,17 +199,19 @@ func (t *transport) readFrame(reader bufiox.Reader) error { } case dataFrameType: // Data Frame: decode and distribute data - sio, ok := t.loadStreamIO(fr.sid) + s, ok := t.loadStream(fr.sid) if ok { - sio.input(context.Background(), streamIOMsg{payload: fr.payload}) + s.sio.input(context.Background(), streamIOMsg{payload: fr.payload}) } else { klog.Errorf("transport[%d] read a unknown stream data: sid=%d", t.kind, fr.sid) } case trailerFrameType: // Trailer Frame: recv trailer, Close read direction - sio, ok := t.loadStreamIO(fr.sid) + s, ok := t.loadStream(fr.sid) if ok { - err = sio.stream.readTrailerFrame(fr) + if sErr := s.readTrailerFrame(fr); sErr != nil { + s.close(sErr) + } } else { // client recv an unknown trailer is in exception, // because the client stream may already be GCed, @@ -223,6 +227,7 @@ func (t *transport) loopRead() error { reader := newReaderBuffer(t.conn.Reader()) for { err := t.readFrame(reader) + // judge whether it is connection-level error // read frame return an un-recovered error, so we should close the transport if err != nil { return err @@ -267,9 +272,15 @@ func (t *transport) writeFrame(sframe streamFrame, meta IntHeader, ftype int32, return t.fpipe.Write(context.Background(), frame) } -func (t *transport) streamSend(ctx context.Context, sid int32, method string, wheader streamx.Header, res any) (err error) { - if len(wheader) > 0 { - err = t.streamSendHeader(sid, method, wheader) +func (t *transport) streamSend(ctx context.Context, s *stream, res any) (err error) { + if s.isClosed() { + return s.sio.exception + } + if s.isSendFinished() { + return io.EOF + } + if len(s.wheader) > 0 { + err = t.streamSendHeader(s, s.wheader) if err != nil { return err } @@ -286,46 +297,38 @@ func (t *transport) streamSend(ctx context.Context, sid int32, method string, wh } } return t.writeFrame( - streamFrame{sid: sid, method: method}, + streamFrame{sid: s.sid, method: s.method}, nil, dataFrameType, payload, ) } -func (t *transport) streamSendHeader(sid int32, method string, header streamx.Header) (err error) { +func (t *transport) streamSendHeader(s *stream, header streamx.Header) (err error) { return t.writeFrame( - streamFrame{sid: sid, method: method, header: header}, + streamFrame{sid: s.sid, method: s.method, header: header}, nil, headerFrameType, nil) } -func (t *transport) streamCloseSend(sid int32, method string, trailer streamx.Trailer, ex tException) (err error) { +func (t *transport) streamCloseSend(s *stream, trailer streamx.Trailer, exception tException) (err error) { var payload []byte - if ex != nil { - payload, err = EncodeException(context.Background(), method, sid, ex) + if exception != nil { + payload, err = EncodeException(context.Background(), s.method, s.sid, exception) if err != nil { return err } } err = t.writeFrame( - streamFrame{sid: sid, method: method, trailer: trailer}, + streamFrame{sid: s.sid, method: s.method, trailer: trailer}, nil, trailerFrameType, payload, ) if err != nil { return err } - sio, ok := t.loadStreamIO(sid) - if !ok { - return nil - } - sio.closeSend() + s.sio.closeSend() return nil } -func (t *transport) streamRecv(ctx context.Context, sid int32, data any) (err error) { - sio, ok := t.loadStreamIO(sid) - if !ok { - return io.EOF - } - msg, err := sio.output(ctx) +func (t *transport) streamRecv(ctx context.Context, s *stream, data any) (err error) { + msg, err := s.sio.output(ctx) if err != nil { return err } @@ -344,14 +347,11 @@ func (t *transport) streamRecv(ctx context.Context, sid int32, data any) (err er } func (t *transport) streamCloseRecv(s *stream, exception error) error { - sio, ok := t.loadStreamIO(s.sid) - if !ok { - return fmt.Errorf("stream not found in stream map: sid=%d", s.sid) - } if exception != nil { - sio.input(context.Background(), streamIOMsg{exception: exception}) + s.close(exception) + } else { + s.sio.closeRecv() } - sio.closeRecv() return nil } @@ -370,12 +370,12 @@ func (t *transport) IsStreaming() bool { var clientStreamID int32 -// newStreamIO create new stream on current connection +// newStream create new stream on current connection // it's typically used by client side -// newStreamIO is concurrency safe -func (t *transport) newStreamIO( +// newStream is concurrency safe +func (t *transport) newStream( ctx context.Context, method string, intHeader IntHeader, strHeader streamx.Header, -) (*streamIO, error) { +) (*stream, error) { if t.kind != clientTransport { return nil, fmt.Errorf("transport already be used as other kind") } @@ -391,9 +391,9 @@ func (t *transport) newStreamIO( return nil, err } s := newStream(t, smode, streamFrame{sid: sid, method: method}) - sio := t.storeStreamIO(ctx, s) + t.storeStream(ctx, s) atomic.AddInt32(&t.streamingFlag, 1) - return sio, nil + return s, nil } // readStream wait for a new incoming stream on current connection diff --git a/pkg/streamx/streamx_common_test.go b/pkg/streamx/streamx_common_test.go index 7a154076d4..48907feea1 100644 --- a/pkg/streamx/streamx_common_test.go +++ b/pkg/streamx/streamx_common_test.go @@ -18,14 +18,14 @@ package streamx_test import ( "context" + "errors" "testing" "github.com/bytedance/gopkg/cloud/metainfo" - "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" - "github.com/cloudwego/kitex/pkg/remote" + terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" ) const ( @@ -80,10 +80,7 @@ func validateMetadata(ctx context.Context) bool { } func assertNormalErr(t *testing.T, err error) { - ex, ok := err.(*thrift.ApplicationException) - test.Assert(t, ok, err) - test.Assert(t, ex.TypeID() == remote.InternalError, ex.TypeID()) - test.Assert(t, ex.Msg() == "biz error: "+normalErrMsg, ex.Msg()) + test.Assert(t, errors.Is(err, terrors.ErrApplicationException), err) } func assertBizErr(t *testing.T, err error) { From 74c33ecfa3eb081bf69b146e08db3d4d05f2924b Mon Sep 17 00:00:00 2001 From: Joway Date: Fri, 1 Nov 2024 17:51:52 +0800 Subject: [PATCH 12/34] refactor: stream transport layer (#1600) --- client/client_streamx.go | 2 +- .../streamxcallopt/call_option.go | 6 +- go.mod | 2 +- go.sum | 2 + pkg/remote/trans/streamx/server_handler.go | 4 +- pkg/streamx/client_options.go | 2 + pkg/streamx/client_provider.go | 1 + pkg/streamx/client_provider_internal.go | 1 + ...nt_option.go => client_provider_option.go} | 23 +- .../provider/ttstream/client_provier.go | 34 +- .../ttstream/client_trans_pool_longconn.go | 2 +- .../ttstream/client_trans_pool_mux.go | 110 ------ .../ttstream/client_trans_pool_muxconn.go | 173 ++++++++ .../ttstream/client_trans_pool_shortconn.go | 6 +- .../ttstream/container/object_pool.go | 11 +- .../ttstream/container/object_pool_test.go | 70 ++++ .../provider/ttstream/error_scenario_test.go | 246 ------------ .../provider/ttstream/errors/errors.go | 7 +- .../errors_test.go} | 21 +- pkg/streamx/provider/ttstream/frame.go | 33 +- .../provider/ttstream/frame_handler.go | 60 ++- pkg/streamx/provider/ttstream/frame_test.go | 10 +- .../provider/ttstream/meta_frame_handler.go | 70 ---- pkg/streamx/provider/ttstream/metadata.go | 6 +- pkg/streamx/provider/ttstream/mock_test.go | 14 +- .../provider/ttstream/server_provider.go | 48 ++- .../ttstream/server_provider_option.go | 34 ++ pkg/streamx/provider/ttstream/stream.go | 368 ++++++++++-------- ...eam_header_trailer.go => stream_client.go} | 45 ++- pkg/streamx/provider/ttstream/stream_io.go | 115 ------ .../provider/ttstream/stream_reader.go | 87 +++++ .../provider/ttstream/stream_reader_test.go | 75 ++++ .../provider/ttstream/stream_server.go | 73 ++++ pkg/streamx/provider/ttstream/stream_test.go | 56 +++ pkg/streamx/provider/ttstream/test_utils.go | 55 +++ pkg/streamx/provider/ttstream/transport.go | 268 ++++--------- .../provider/ttstream/transport_test.go | 230 +++++++++++ pkg/streamx/server_provider.go | 4 +- pkg/streamx/server_provider_internal.go | 1 + pkg/streamx/stream.go | 46 +-- pkg/streamx/streamx_gen_service_test.go | 2 +- pkg/streamx/streamx_user_test.go | 209 ++++++---- server/streamxserver/server_gen.go | 2 +- 43 files changed, 1557 insertions(+), 1077 deletions(-) rename pkg/streamx/provider/ttstream/{client_option.go => client_provider_option.go} (54%) delete mode 100644 pkg/streamx/provider/ttstream/client_trans_pool_mux.go create mode 100644 pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go create mode 100644 pkg/streamx/provider/ttstream/container/object_pool_test.go delete mode 100644 pkg/streamx/provider/ttstream/error_scenario_test.go rename pkg/streamx/provider/ttstream/{server_option.go => errors/errors_test.go} (58%) delete mode 100644 pkg/streamx/provider/ttstream/meta_frame_handler.go create mode 100644 pkg/streamx/provider/ttstream/server_provider_option.go rename pkg/streamx/provider/ttstream/{stream_header_trailer.go => stream_client.go} (63%) delete mode 100644 pkg/streamx/provider/ttstream/stream_io.go create mode 100644 pkg/streamx/provider/ttstream/stream_reader.go create mode 100644 pkg/streamx/provider/ttstream/stream_reader_test.go create mode 100644 pkg/streamx/provider/ttstream/stream_server.go create mode 100644 pkg/streamx/provider/ttstream/stream_test.go create mode 100644 pkg/streamx/provider/ttstream/test_utils.go create mode 100644 pkg/streamx/provider/ttstream/transport_test.go diff --git a/client/client_streamx.go b/client/client_streamx.go index 73ca7264d3..a0eb34829f 100644 --- a/client/client_streamx.go +++ b/client/client_streamx.go @@ -56,7 +56,7 @@ func (kc *kClient) NewStream(ctx context.Context, method string, req any, callOp // tracing ctx = kc.opt.TracerCtl.DoStart(ctx, ri) ctx, copts := streamxcallopt.NewCtxWithCallOptions(ctx) - callOptions = append(callOptions, streamxcallopt.WithStreamCloseCallback(func(ctx context.Context) { + callOptions = append(callOptions, streamxcallopt.WithStreamCloseCallback(func() { kc.opt.TracerCtl.DoFinish(ctx, ri, err) })) copts.Apply(callOptions) diff --git a/client/streamxclient/streamxcallopt/call_option.go b/client/streamxclient/streamxcallopt/call_option.go index 5e2325ab0f..bbd6a157cc 100644 --- a/client/streamxclient/streamxcallopt/call_option.go +++ b/client/streamxclient/streamxcallopt/call_option.go @@ -20,10 +20,10 @@ import ( "context" ) -type StreamCloseCallback func(ctx context.Context) +type StreamCloseCallback func() type CallOptions struct { - StreamCloseCallback StreamCloseCallback + StreamCloseCallback []StreamCloseCallback } type CallOption struct { @@ -59,6 +59,6 @@ func (copts *CallOptions) Apply(opts []CallOption) { func WithStreamCloseCallback(callback StreamCloseCallback) CallOption { return CallOption{f: func(o *CallOptions) { - o.StreamCloseCallback = callback + o.StreamCloseCallback = append(o.StreamCloseCallback, callback) }} } diff --git a/go.mod b/go.mod index c996b007a5..bdb558128e 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/cloudwego/frugal v0.2.0 github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4 github.com/cloudwego/localsession v0.1.1 - github.com/cloudwego/netpoll v0.6.4 + github.com/cloudwego/netpoll v0.6.5-0.20240911104114-8a1f5597a920 github.com/cloudwego/runtimex v0.1.0 github.com/cloudwego/thriftgo v0.3.18 github.com/golang/mock v1.6.0 diff --git a/go.sum b/go.sum index cf1260511b..77d85246ad 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/cloudwego/localsession v0.1.1 h1:tbK7laDVrYfFDXoBXo4uCGMAxU4qmz2dDm8d github.com/cloudwego/localsession v0.1.1/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiXcs7Z+KBHP72Wv8= github.com/cloudwego/netpoll v0.6.4 h1:z/dA4sOTUQof6zZIO4QNnLBXsDFFFEos9OOGloR6kno= github.com/cloudwego/netpoll v0.6.4/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= +github.com/cloudwego/netpoll v0.6.5-0.20240911104114-8a1f5597a920 h1:WT7vsDDb+ammyB7XLmNSS4vKGpPvM2JDl6h34Jj7mY4= +github.com/cloudwego/netpoll v0.6.5-0.20240911104114-8a1f5597a920/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= github.com/cloudwego/runtimex v0.1.0 h1:HG+WxWoj5/CDChDZ7D99ROwvSMkuNXAqt6hnhTTZDiI= github.com/cloudwego/runtimex v0.1.0/go.mod h1:23vL/HGV0W8nSCHbe084AgEBdDV4rvXenEUMnUNvUd8= github.com/cloudwego/thriftgo v0.3.18 h1:gnr1vz7G3RbwwCK9AMKHZf63VYGa7ene6WbI9VrBJSw= diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index 99a0ae069c..b4408bad43 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -99,11 +99,13 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) var wg sync.WaitGroup defer func() { wg.Wait() - klog.CtxErrorf(ctx, "KITEX: stream OnRead return: err=%v", err) _, nerr := t.provider.OnInactive(ctx, conn) if err == nil && nerr != nil { err = nerr } + if err != nil { + klog.CtxErrorf(ctx, "KITEX: stream OnRead return: err=%v", err) + } }() // connection level goroutine for { diff --git a/pkg/streamx/client_options.go b/pkg/streamx/client_options.go index 935ad0d61b..9bded6a252 100644 --- a/pkg/streamx/client_options.go +++ b/pkg/streamx/client_options.go @@ -23,8 +23,10 @@ import ( "github.com/cloudwego/kitex/pkg/stats" ) +// EventHandler define stats event handler type EventHandler func(ctx context.Context, evt stats.Event, err error) +// ClientOptions define the client options type ClientOptions struct { RecvTimeout time.Duration StreamMWs []StreamMiddleware diff --git a/pkg/streamx/client_provider.go b/pkg/streamx/client_provider.go index 1ed92c5ae6..b6da978caf 100644 --- a/pkg/streamx/client_provider.go +++ b/pkg/streamx/client_provider.go @@ -39,6 +39,7 @@ res := stream.Recv(...) => clientProvider.Stream.Recv(...) */ +// ClientProvider define client provider API type ClientProvider interface { // NewStream create a stream based on rpcinfo and callOptions NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (ClientStream, error) diff --git a/pkg/streamx/client_provider_internal.go b/pkg/streamx/client_provider_internal.go index ba3f4da6f0..7f80aa15be 100644 --- a/pkg/streamx/client_provider_internal.go +++ b/pkg/streamx/client_provider_internal.go @@ -22,6 +22,7 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" ) +// NewClientProvider wrap specific client provider func NewClientProvider(cs ClientProvider) ClientProvider { return internalClientProvider{ClientProvider: cs} } diff --git a/pkg/streamx/provider/ttstream/client_option.go b/pkg/streamx/provider/ttstream/client_provider_option.go similarity index 54% rename from pkg/streamx/provider/ttstream/client_option.go rename to pkg/streamx/provider/ttstream/client_provider_option.go index d27fb6c3fe..85b5bfac58 100644 --- a/pkg/streamx/provider/ttstream/client_option.go +++ b/pkg/streamx/provider/ttstream/client_provider_option.go @@ -16,34 +16,47 @@ package ttstream +// ClientProviderOption define client provider options type ClientProviderOption func(cp *clientProvider) -func WithClientMetaHandler(metaHandler MetaFrameHandler) ClientProviderOption { +// WithClientMetaFrameHandler register TTHeader Streaming meta frame handler +func WithClientMetaFrameHandler(handler MetaFrameHandler) ClientProviderOption { return func(cp *clientProvider) { - cp.metaHandler = metaHandler + cp.metaHandler = handler } } -func WithClientHeaderHandler(handler HeaderFrameHandler) ClientProviderOption { +// WithClientHeaderFrameHandler register TTHeader Streaming header frame handler +func WithClientHeaderFrameHandler(handler HeaderFrameWriteHandler) ClientProviderOption { return func(cp *clientProvider) { cp.headerHandler = handler } } +// WithClientDisableCancelingTransmit disable canceling transmit from upstream +func WithClientDisableCancelingTransmit() ClientProviderOption { + return func(cp *clientProvider) { + cp.disableCancelingTransmit = true + } +} + +// WithClientLongConnPool using long connection pool for client func WithClientLongConnPool(config LongConnConfig) ClientProviderOption { return func(cp *clientProvider) { cp.transPool = newLongConnTransPool(config) } } +// WithClientShortConnPool using short connection pool for client func WithClientShortConnPool() ClientProviderOption { return func(cp *clientProvider) { cp.transPool = newShortConnTransPool() } } -func WithClientMuxConnPool() ClientProviderOption { +// WithClientMuxConnPool using mux connection pool for client +func WithClientMuxConnPool(config MuxConnConfig) ClientProviderOption { return func(cp *clientProvider) { - cp.transPool = newMuxTransPool() + cp.transPool = newMuxConnTransPool(config) } } diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go index ba78a1931f..f7556c7068 100644 --- a/pkg/streamx/provider/ttstream/client_provier.go +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -32,10 +32,11 @@ import ( var _ streamx.ClientProvider = (*clientProvider)(nil) +// NewClientProvider return a client provider func NewClientProvider(sinfo *serviceinfo.ServiceInfo, opts ...ClientProviderOption) (streamx.ClientProvider, error) { cp := new(clientProvider) cp.sinfo = sinfo - cp.transPool = newMuxTransPool() + cp.transPool = newMuxConnTransPool(DefaultMuxConnConfig) for _, opt := range opts { opt(cp) } @@ -46,9 +47,13 @@ type clientProvider struct { transPool transPool sinfo *serviceinfo.ServiceInfo metaHandler MetaFrameHandler - headerHandler HeaderFrameHandler + headerHandler HeaderFrameWriteHandler + + // options + disableCancelingTransmit bool } +// NewStream return a client stream func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (streamx.ClientStream, error) { rconfig := ri.Config() invocation := ri.Invocation() @@ -62,7 +67,7 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (stre var intHeader IntHeader var err error if c.headerHandler != nil { - intHeader, strHeader, err = c.headerHandler.OnStream(ctx) + intHeader, strHeader, err = c.headerHandler.OnWriteStream(ctx) if err != nil { return nil, err } @@ -78,30 +83,21 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (stre return nil, err } - s, err := trans.newStream(ctx, method, intHeader, strHeader) + s, err := trans.WriteStream(ctx, method, intHeader, strHeader) if err != nil { return nil, err } s.setRecvTimeout(rconfig.StreamRecvTimeout()) - // only client can set meta frame handler s.setMetaFrameHandler(c.metaHandler) // if ctx from server side, we should cancel the stream when server handler already returned - // TODO: this canceling transmit should be configurable - ktx.RegisterCancelCallback(ctx, func() { - s.cancel() - }) + if !c.disableCancelingTransmit { + ktx.RegisterCancelCallback(ctx, func() { + _ = s.cancel() + }) + } cs := newClientStream(s) - runtime.SetFinalizer(cs, func(cstream *clientStream) { - // it's safe to call CloseSend twice - // we do CloseSend here to ensure stream can be closed normally - _ = cstream.CloseSend(ctx) - - s.close(nil) - if trans.IsActive() { - c.transPool.Put(trans) - } - }) + runtime.SetFinalizer(cs, clientStreamFinalizer) return cs, err } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go index bae4f6ec6e..87c9109c79 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go @@ -63,7 +63,7 @@ func (c *longConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr st if err != nil { return nil, err } - trans = newTransport(clientTransport, sinfo, conn) + trans = newTransport(clientTransport, sinfo, conn, c) // create new transport return trans, nil } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go deleted file mode 100644 index b16acab47e..0000000000 --- a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ttstream - -import ( - "errors" - "runtime" - "sync" - "sync/atomic" - "time" - - "github.com/cloudwego/netpoll" - "golang.org/x/sync/singleflight" - - "github.com/cloudwego/kitex/pkg/serviceinfo" - terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" -) - -var _ transPool = (*muxTransPool)(nil) - -type muxTransList struct { - L sync.RWMutex - size int - cursor uint32 - transports []*transport -} - -func newMuxTransList(size int) *muxTransList { - tl := new(muxTransList) - tl.size = size - tl.transports = make([]*transport, size) - return tl -} - -func (tl *muxTransList) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (*transport, error) { - idx := atomic.AddUint32(&tl.cursor, 1) % uint32(tl.size) - tl.L.RLock() - trans := tl.transports[idx] - tl.L.RUnlock() - if trans != nil && trans.IsActive() { - return trans, nil - } - - conn, err := dialer.DialConnection(network, addr, time.Second) - if err != nil { - return nil, err - } - trans = newTransport(clientTransport, sinfo, conn) - _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { - // peer close - _ = trans.Close(terrors.ErrTransport.WithCause(errors.New("netpoll connection closed"))) - return nil - }) - runtime.SetFinalizer(trans, func(trans *transport) { - // self close when not hold by user - // todo: think about a more ideal error - _ = trans.Close(nil) - }) - tl.L.Lock() - tl.transports[idx] = trans - tl.L.Unlock() - return trans, nil -} - -func newMuxTransPool() transPool { - t := new(muxTransPool) - t.poolSize = runtime.GOMAXPROCS(0) - return t -} - -type muxTransPool struct { - poolSize int - pool sync.Map // addr:*muxTransList - sflight singleflight.Group -} - -func (m *muxTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (trans *transport, err error) { - v, ok := m.pool.Load(addr) - if ok { - return v.(*muxTransList).Get(sinfo, network, addr) - } - - v, err, _ = m.sflight.Do(addr, func() (interface{}, error) { - transList := newMuxTransList(m.poolSize) - m.pool.Store(addr, transList) - return transList, nil - }) - if err != nil { - return nil, err - } - return v.(*muxTransList).Get(sinfo, network, addr) -} - -func (m *muxTransPool) Put(trans *transport) { - // do nothing -} diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go new file mode 100644 index 0000000000..25082b3c7c --- /dev/null +++ b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go @@ -0,0 +1,173 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "errors" + "fmt" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/cloudwego/netpoll" + "golang.org/x/sync/singleflight" + + "github.com/cloudwego/kitex/pkg/serviceinfo" + terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" +) + +var DefaultMuxConnConfig = MuxConnConfig{ + PoolSize: runtime.GOMAXPROCS(0), + MaxIdleTimeout: time.Minute, +} + +type MuxConnConfig struct { + PoolSize int + MaxIdleTimeout time.Duration +} + +var _ transPool = (*muxConnTransPool)(nil) + +type muxConnTransList struct { + L sync.RWMutex + size int + cursor uint32 + transports []*transport + pool transPool + sf singleflight.Group +} + +func newMuxConnTransList(size int, pool transPool) *muxConnTransList { + tl := new(muxConnTransList) + if size == 0 { + size = runtime.GOMAXPROCS(0) + } + tl.size = size + tl.transports = make([]*transport, size) + tl.pool = pool + return tl +} + +func (tl *muxConnTransList) Close() { + tl.L.Lock() + for i, t := range tl.transports { + _ = t.Close(nil) + tl.transports[i] = nil + } + tl.L.Unlock() +} + +func (tl *muxConnTransList) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (*transport, error) { + idx := atomic.AddUint32(&tl.cursor, 1) % uint32(tl.size) + tl.L.RLock() + trans := tl.transports[idx] + tl.L.RUnlock() + if trans != nil && trans.IsActive() { + return trans, nil + } + + v, err, _ := tl.sf.Do(fmt.Sprintf("%d", idx), func() (interface{}, error) { + conn, err := dialer.DialConnection(network, addr, time.Second) + if err != nil { + return nil, err + } + trans := newTransport(clientTransport, sinfo, conn, tl.pool) + _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { + // peer close + _ = trans.Close(terrors.ErrTransport.WithCause(errors.New("connection closed by peer"))) + return nil + }) + runtime.SetFinalizer(trans, func(trans *transport) { + // self close when not hold by user + _ = trans.Close(nil) + }) + tl.L.Lock() + tl.transports[idx] = trans + tl.L.Unlock() + return trans, nil + }) + if err != nil { + return nil, err + } + trans = v.(*transport) + return trans, nil +} + +func newMuxConnTransPool(config MuxConnConfig) transPool { + t := new(muxConnTransPool) + t.config = config + return t +} + +type muxConnTransPool struct { + config MuxConnConfig + pool sync.Map // addr:*muxConnTransList + activity sync.Map // addr:lastActive + sflight singleflight.Group + cleanerOnce sync.Once +} + +func (p *muxConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (trans *transport, err error) { + v, ok := p.pool.Load(addr) + if ok { + return v.(*muxConnTransList).Get(sinfo, network, addr) + } + + v, err, _ = p.sflight.Do(addr, func() (interface{}, error) { + transList := newMuxConnTransList(p.config.PoolSize, p) + p.pool.Store(addr, transList) + return transList, nil + }) + if err != nil { + return nil, err + } + return v.(*muxConnTransList).Get(sinfo, network, addr) +} + +func (p *muxConnTransPool) Put(trans *transport) { + p.activity.Store(trans.conn.RemoteAddr().String(), time.Now()) + p.cleanerOnce.Do(func() { + internal := p.config.MaxIdleTimeout + if internal == 0 { + return + } + go func() { + for { + now := time.Now() + count := 0 + p.activity.Range(func(addr, value interface{}) bool { + count++ + lastActive := value.(time.Time) + if lastActive.IsZero() || now.Sub(lastActive) < p.config.MaxIdleTimeout { + return true + } + v, _ := p.pool.Load(addr) + if v == nil { + return true + } + transList := v.(*muxConnTransList) + p.pool.Delete(addr) + p.activity.Delete(addr) + transList.Close() + return true + }) + time.Sleep(internal) + } + }() + }) +} diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go index a5c0661a0a..71e15e94ed 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go @@ -30,17 +30,17 @@ func newShortConnTransPool() transPool { type shortConnTransPool struct{} -func (c *shortConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (*transport, error) { +func (p *shortConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (*transport, error) { // create new connection conn, err := dialer.DialConnection(network, addr, time.Second) if err != nil { return nil, err } // create new transport - trans := newTransport(clientTransport, sinfo, conn) + trans := newTransport(clientTransport, sinfo, conn, p) return trans, nil } -func (c *shortConnTransPool) Put(trans *transport) { +func (p *shortConnTransPool) Put(trans *transport) { _ = trans.Close(terrors.ErrTransport.WithCause(errors.New("short connection closed"))) } diff --git a/pkg/streamx/provider/ttstream/container/object_pool.go b/pkg/streamx/provider/ttstream/container/object_pool.go index 8450f357fa..6ac599cbdb 100644 --- a/pkg/streamx/provider/ttstream/container/object_pool.go +++ b/pkg/streamx/provider/ttstream/container/object_pool.go @@ -76,21 +76,12 @@ func (s *ObjectPool) Close() { } func (s *ObjectPool) cleaning() { - cleanInternal := time.Second + cleanInternal := s.idleTimeout for atomic.LoadInt32(&s.closed) == 0 { time.Sleep(cleanInternal) now := time.Now() s.L.Lock() - // update cleanInternal - objSize := 0 - for _, stk := range s.objects { - objSize += stk.Size() - } - cleanInternal = time.Second + time.Duration(objSize)*time.Millisecond*10 - if cleanInternal > time.Second*10 { - cleanInternal = time.Second * 10 - } // clean objects for _, stk := range s.objects { deleted := 0 diff --git a/pkg/streamx/provider/ttstream/container/object_pool_test.go b/pkg/streamx/provider/ttstream/container/object_pool_test.go new file mode 100644 index 0000000000..540d5dd050 --- /dev/null +++ b/pkg/streamx/provider/ttstream/container/object_pool_test.go @@ -0,0 +1,70 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package container + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" +) + +var _ Object = (*testObject)(nil) + +type testObject struct { + closeCallback func() +} + +func (o *testObject) Close(exception error) error { + if o.closeCallback != nil { + o.closeCallback() + } + return nil +} + +func TestObjectPool(t *testing.T) { + op := NewObjectPool(time.Microsecond * 10) + count := 10000 + key := "test" + var wg sync.WaitGroup + for i := 0; i < count; i++ { + o := new(testObject) + o.closeCallback = func() { + wg.Done() + } + wg.Add(1) + op.Push(key, o) + } + wg.Wait() + test.Assert(t, op.objects[key].Size() == 0) + op.Close() + + op = NewObjectPool(time.Second) + var deleted int32 + for i := 0; i < count; i++ { + o := new(testObject) + o.closeCallback = func() { + atomic.AddInt32(&deleted, 1) + } + op.Push(key, o) + } + test.Assert(t, atomic.LoadInt32(&deleted) == 0) + test.Assert(t, op.objects[key].Size() == count) + op.Close() +} diff --git a/pkg/streamx/provider/ttstream/error_scenario_test.go b/pkg/streamx/provider/ttstream/error_scenario_test.go deleted file mode 100644 index 1ed00344f9..0000000000 --- a/pkg/streamx/provider/ttstream/error_scenario_test.go +++ /dev/null @@ -1,246 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ttstream - -import ( - "context" - "encoding/binary" - "errors" - "io" - "net" - "testing" - "time" - - "github.com/bytedance/gopkg/cloud/metainfo" - "github.com/cloudwego/gopkg/bufiox" - "github.com/cloudwego/gopkg/protocol/thrift" - "github.com/cloudwego/gopkg/protocol/ttheader" - "github.com/cloudwego/netpoll" - - "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/klog" - "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/kitex/pkg/streamx" - terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" -) - -const ( - testTypeKey = "testType" - - testTypeIllegalFrame = "illegalFrame" - testTypeUnexpectedHeaderFrame = "unexpectedHeaderFrame" - testTypeUnexpectedTrailerFrame = "unexpectedTrailerFrame" - testTypeIllegalBizErr = "illegalBizErr" - testTypeApplicationException = "applicationException" -) - -var streamingServiceInfo = &serviceinfo.ServiceInfo{ - ServiceName: "kitex.service.streaming", - Methods: map[string]serviceinfo.MethodInfo{ - "TriggerStreamErr": serviceinfo.NewMethodInfo( - func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return nil - }, - nil, - nil, - false, - serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), - ), - }, - Extra: map[string]interface{}{"streamingFlag": true, "streamx": true}, -} - -type illegalFrameType int32 - -const ( - MissStreamingFlag illegalFrameType = iota -) - -func encodeIllegalFrame(t *testing.T, ctx context.Context, writer bufiox.Writer, fr *Frame, flag illegalFrameType) { - var param ttheader.EncodeParam - written := writer.WrittenLen() - switch flag { - case MissStreamingFlag: - param = ttheader.EncodeParam{ - SeqID: fr.sid, - ProtocolID: ttheader.ProtocolIDThriftStruct, - } - param.IntInfo = fr.meta - if param.IntInfo == nil { - param.IntInfo = make(IntHeader) - } - param.IntInfo[ttheader.FrameType] = frameTypeToString[fr.typ] - param.IntInfo[ttheader.ToMethod] = fr.method - totalLenField, err := ttheader.Encode(ctx, param, writer) - if err != nil { - t.Errorf("ttheader Encode failed, err: %v", err) - } - written = writer.WrittenLen() - written - binary.BigEndian.PutUint32(totalLenField, uint32(written-4)) - } -} - -func TestErrorScenario(t *testing.T) { - klog.SetLevel(klog.LevelDebug) - addr := test.GetLocalAddress() - nAddr, err := net.ResolveTCPAddr("tcp", addr) - test.Assert(t, err == nil, err) - ln, err := netpoll.CreateListener("tcp", addr) - test.Assert(t, err == nil, err) - defer ln.Close() - sp, err := NewServerProvider(streamingServiceInfo) - test.Assert(t, err == nil, err) - onConnect := func(ctx context.Context, conn netpoll.Connection) context.Context { - nctx, err := sp.OnActive(ctx, conn) - test.Assert(t, err == nil, err) - nctx, ss, nerr := sp.OnStream(nctx, conn) - test.Assert(t, nerr == nil, nerr) - go func() { - rawss := ss.(*serverStream) - testType, ok := metainfo.GetValue(nctx, testTypeKey) - test.Assert(t, ok) - switch testType { - case testTypeIllegalFrame: - encodeIllegalFrame(t, nctx, newWriterBuffer(rawss.trans.conn.Writer()), &Frame{ - streamFrame: streamFrame{ - sid: rawss.sid, - }, - typ: headerFrameType, - }, MissStreamingFlag) - rawss.trans.conn.Writer().Flush() - case testTypeUnexpectedHeaderFrame: - hd := streamx.Header{ - "key": "val", - } - rawss.trans.streamSendHeader(rawss.stream, hd) - rawss.trans.streamSendHeader(rawss.stream, hd) - case testTypeUnexpectedTrailerFrame: - rawss.trans.streamCloseSend(rawss.stream, nil, nil) - rawss.trans.streamCloseSend(rawss.stream, nil, nil) - case testTypeIllegalBizErr: - err = rawss.writeTrailer( - streamx.Trailer{ - "biz-status": "1", - "biz-message": "message", - "biz-extra": "invalid extra JSON str", - }, - ) - test.Assert(t, err == nil, err) - err = rawss.sendTrailer(nctx, nil) - test.Assert(t, err == nil, err) - case testTypeApplicationException: - exception := thrift.NewApplicationException(remote.InternalError, "testApplicationException") - err = rawss.sendTrailer(nctx, exception) - test.Assert(t, err == nil, err) - } - }() - return nctx - } - loop, err := netpoll.NewEventLoop(nil, - netpoll.WithOnConnect(onConnect), - netpoll.WithReadTimeout(10*time.Second), - ) - test.Assert(t, err == nil, err) - go func() { - if err := loop.Serve(ln); err != nil { - t.Logf("server failed, err: %v", err) - } - }() - test.WaitServerStart(addr) - defer func() { - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - if err := loop.Shutdown(ctx); err != nil { - t.Logf("netpoll eventloop shutdown failed, err: %v", err) - } - }() - - cp, err := NewClientProvider(streamingServiceInfo) - test.Assert(t, err == nil, err) - cctx := context.Background() - cfg := rpcinfo.NewRPCConfig() - cfg.(rpcinfo.MutableRPCConfig).SetStreamRecvTimeout(10 * time.Second) - method := "TriggerStreamErr" - ri := rpcinfo.NewRPCInfo( - rpcinfo.NewEndpointInfo(streamingServiceInfo.ServiceName, method, nAddr, nil), - rpcinfo.NewEndpointInfo(streamingServiceInfo.ServiceName, method, nAddr, nil), - rpcinfo.NewInvocation(streamingServiceInfo.ServiceName, method), - cfg, - rpcinfo.NewRPCStats(), - ) - - t.Run("Illegal Frame", func(t *testing.T) { - t.Run("Non-streaming Frame", func(t *testing.T) { - nctx := metainfo.WithValue(cctx, testTypeKey, testTypeIllegalFrame) - cs, err := cp.NewStream(nctx, ri) - test.Assert(t, err == nil, err) - rawcs := cs.(*clientStream) - err = rawcs.RecvMsg(nctx, nil) - test.Assert(t, errors.Is(err, terrors.ErrIllegalFrame), err) - err = rawcs.SendMsg(nctx, nil) - test.Assert(t, errors.Is(err, terrors.ErrIllegalFrame), err) - }) - }) - - t.Run("Illegal Header Frame", func(t *testing.T) { - t.Run("Receive multiple header", func(t *testing.T) { - nctx := metainfo.WithValue(cctx, testTypeKey, testTypeUnexpectedHeaderFrame) - cs, err := cp.NewStream(nctx, ri) - test.Assert(t, err == nil, err) - rawcs := cs.(*clientStream) - err = rawcs.RecvMsg(nctx, nil) - test.Assert(t, errors.Is(err, terrors.ErrUnexpectedHeader), err) - err = rawcs.SendMsg(nctx, nil) - test.Assert(t, errors.Is(err, terrors.ErrUnexpectedHeader), err) - }) - t.Run("Receive multiple trailer", func(t *testing.T) { - nctx := metainfo.WithValue(cctx, testTypeKey, testTypeUnexpectedTrailerFrame) - cs, err := cp.NewStream(nctx, ri) - test.Assert(t, err == nil, err) - rawcs := cs.(*clientStream) - err = rawcs.RecvMsg(nctx, nil) - test.Assert(t, errors.Is(err, io.EOF), err) - // wait for second trailer frame - time.Sleep(50 * time.Millisecond) - err = rawcs.SendMsg(nctx, nil) - test.Assert(t, errors.Is(err, io.EOF), err) - }) - }) - - t.Run("Trailer Frame", func(t *testing.T) { - t.Run("Illegal BizErr", func(t *testing.T) { - nctx := metainfo.WithValue(cctx, testTypeKey, testTypeIllegalBizErr) - cs, err := cp.NewStream(nctx, ri) - test.Assert(t, err == nil, err) - rawcs := cs.(*clientStream) - err = rawcs.RecvMsg(nctx, nil) - test.Assert(t, errors.Is(err, terrors.ErrIllegalBizErr), err) - err = rawcs.SendMsg(nctx, nil) - test.Assert(t, errors.Is(err, terrors.ErrIllegalBizErr), err) - }) - t.Run("Application Exception", func(t *testing.T) { - nctx := metainfo.WithValue(cctx, testTypeKey, testTypeApplicationException) - cs, err := cp.NewStream(nctx, ri) - test.Assert(t, err == nil, err) - rawcs := cs.(*clientStream) - err = rawcs.RecvMsg(nctx, nil) - test.Assert(t, errors.Is(err, terrors.ErrApplicationException), err) - }) - }) -} diff --git a/pkg/streamx/provider/ttstream/errors/errors.go b/pkg/streamx/provider/ttstream/errors/errors.go index 4eac75f73e..7d75e331ad 100644 --- a/pkg/streamx/provider/ttstream/errors/errors.go +++ b/pkg/streamx/provider/ttstream/errors/errors.go @@ -16,7 +16,9 @@ package errors -import "errors" +import ( + "errors" +) var ( ErrUnexpectedHeader = &errType{message: "unexpected header frame"} @@ -24,6 +26,7 @@ var ( ErrApplicationException = &errType{message: "application exception"} ErrIllegalBizErr = &errType{message: "illegal bizErr"} ErrIllegalFrame = &errType{message: "illegal frame"} + ErrIllegalOperation = &errType{message: "illegal operation"} ErrTransport = &errType{message: "transport is closing"} ) @@ -34,7 +37,7 @@ type errType struct { } func (e *errType) WithCause(err error) error { - return &errType{basic: e, cause: err} + return &errType{message: e.message, basic: e, cause: err} } func (e *errType) Error() string { diff --git a/pkg/streamx/provider/ttstream/server_option.go b/pkg/streamx/provider/ttstream/errors/errors_test.go similarity index 58% rename from pkg/streamx/provider/ttstream/server_option.go rename to pkg/streamx/provider/ttstream/errors/errors_test.go index 2e2b619096..d926d3786d 100644 --- a/pkg/streamx/provider/ttstream/server_option.go +++ b/pkg/streamx/provider/ttstream/errors/errors_test.go @@ -14,12 +14,21 @@ * limitations under the License. */ -package ttstream +package errors -type ServerProviderOption func(pc *serverProvider) +import ( + "errors" + "fmt" + "strings" + "testing" -func WithServerPayloadLimit(limit int) ServerProviderOption { - return func(s *serverProvider) { - s.payloadLimit = limit - } + "github.com/cloudwego/kitex/internal/test" +) + +func TestErrors(t *testing.T) { + causeErr := fmt.Errorf("test1") + newErr := ErrIllegalFrame.WithCause(causeErr) + test.Assert(t, errors.Is(newErr, ErrIllegalFrame), newErr) + test.Assert(t, strings.Contains(newErr.Error(), ErrIllegalFrame.Error())) + test.Assert(t, strings.Contains(newErr.Error(), causeErr.Error())) } diff --git a/pkg/streamx/provider/ttstream/frame.go b/pkg/streamx/provider/ttstream/frame.go index 24ca51faf9..54169c0bd0 100644 --- a/pkg/streamx/provider/ttstream/frame.go +++ b/pkg/streamx/provider/ttstream/frame.go @@ -19,7 +19,9 @@ package ttstream import ( "context" "encoding/binary" + "errors" "fmt" + "io" "sync" "github.com/bytedance/gopkg/lang/mcache" @@ -27,6 +29,8 @@ import ( gopkgthrift "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/streamx" terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" ) @@ -47,14 +51,18 @@ var frameTypeToString = map[int32]string{ var framePool sync.Pool +// Frame define a TTHeader Streaming Frame type Frame struct { streamFrame - meta IntHeader typ int32 payload []byte } -func newFrame(sframe streamFrame, meta IntHeader, typ int32, payload []byte) (fr *Frame) { +func (f *Frame) String() string { + return fmt.Sprintf("[sid=%d ftype=%d fmethod=%s]", f.sid, f.typ, f.method) +} + +func newFrame(sframe streamFrame, typ int32, payload []byte) (fr *Frame) { v := framePool.Get() if v == nil { fr = new(Frame) @@ -62,7 +70,6 @@ func newFrame(sframe streamFrame, meta IntHeader, typ int32, payload []byte) (fr fr = v.(*Frame) } fr.streamFrame = sframe - fr.meta = meta fr.typ = typ fr.payload = payload return fr @@ -70,7 +77,6 @@ func newFrame(sframe streamFrame, meta IntHeader, typ int32, payload []byte) (fr func recycleFrame(frame *Frame) { frame.streamFrame = streamFrame{} - frame.meta = nil frame.typ = 0 frame.payload = nil framePool.Put(frame) @@ -122,6 +128,9 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro var dp ttheader.DecodeParam dp, err = ttheader.Decode(ctx, reader) if err != nil { + if errors.Is(err, io.EOF) { + return nil, err + } return nil, terrors.ErrIllegalFrame.WithCause(err) } if dp.Flags&ttheader.HeaderFlagsStreaming == 0 { @@ -131,6 +140,7 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro var ftype int32 var fheader streamx.Header var ftrailer streamx.Trailer + fmeta := dp.IntInfo switch dp.IntInfo[ttheader.FrameType] { case ttheader.FrameTypeMeta: ftype = metaFrameType @@ -162,8 +172,7 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro } fr = newFrame( - streamFrame{sid: fsid, method: fmethod, header: fheader, trailer: ftrailer}, - dp.IntInfo, + streamFrame{sid: fsid, method: fmethod, meta: fmeta, header: fheader, trailer: ftrailer}, ftype, fpayload, ) return fr, nil @@ -174,10 +183,18 @@ func EncodePayload(ctx context.Context, msg any) ([]byte, error) { return payload, nil } +func EncodeGenericPayload(ctx context.Context, msg any) ([]byte, error) { + return nil, nil +} + func DecodePayload(ctx context.Context, payload []byte, msg any) error { return gopkgthrift.FastUnmarshal(payload, msg.(gopkgthrift.FastCodec)) } -func EncodeException(ctx context.Context, method string, seq int32, ex tException) ([]byte, error) { - return gopkgthrift.MarshalFastMsg(method, gopkgthrift.EXCEPTION, seq, ex.(gopkgthrift.FastCodec)) +func EncodeException(ctx context.Context, method string, seq int32, ex error) ([]byte, error) { + exception, ok := ex.(gopkgthrift.FastCodec) + if !ok { + exception = gopkgthrift.NewApplicationException(remote.InternalError, ex.Error()) + } + return gopkgthrift.MarshalFastMsg(method, gopkgthrift.EXCEPTION, seq, exception) } diff --git a/pkg/streamx/provider/ttstream/frame_handler.go b/pkg/streamx/provider/ttstream/frame_handler.go index 033a64bde2..3581aeb35e 100644 --- a/pkg/streamx/provider/ttstream/frame_handler.go +++ b/pkg/streamx/provider/ttstream/frame_handler.go @@ -18,10 +18,68 @@ package ttstream import ( "context" + "sync" "github.com/cloudwego/kitex/pkg/streamx" ) type HeaderFrameHandler interface { - OnStream(ctx context.Context) (IntHeader, streamx.Header, error) + HeaderFrameReadHandler + HeaderFrameWriteHandler +} + +type HeaderFrameWriteHandler interface { + OnWriteStream(ctx context.Context) (ihd IntHeader, shd StrHeader, err error) +} + +type HeaderFrameReadHandler interface { + OnReadStream(ctx context.Context, ihd IntHeader, shd StrHeader) (context.Context, error) +} + +type MetaFrameHandler interface { + OnMetaFrame(smeta StreamMeta, intHeader IntHeader, header streamx.Header, payload []byte) error +} + +// StreamMeta is a kv storage for meta frame handler +type StreamMeta interface { + Meta() map[string]string + GetMeta(k string) (string, bool) + SetMeta(k, v string, kvs ...string) +} + +var _ StreamMeta = (*streamMeta)(nil) + +func newStreamMeta() StreamMeta { + return &streamMeta{} +} + +type streamMeta struct { + sync sync.RWMutex + data map[string]string +} + +func (s *streamMeta) Meta() map[string]string { + s.sync.RLock() + m := make(map[string]string, len(s.data)) + for k, v := range s.data { + m[k] = v + } + s.sync.RUnlock() + return m +} + +func (s *streamMeta) GetMeta(k string) (string, bool) { + s.sync.RLock() + v, ok := s.data[k] + s.sync.RUnlock() + return v, ok +} + +func (s *streamMeta) SetMeta(k, v string, kvs ...string) { + s.sync.RLock() + s.data[k] = v + for i := 0; i < len(kvs); i += 2 { + s.data[kvs[i]] = kvs[i+1] + } + s.sync.RUnlock() } diff --git a/pkg/streamx/provider/ttstream/frame_test.go b/pkg/streamx/provider/ttstream/frame_test.go index 6bba79a251..10cc68d6ad 100644 --- a/pkg/streamx/provider/ttstream/frame_test.go +++ b/pkg/streamx/provider/ttstream/frame_test.go @@ -34,7 +34,7 @@ func TestFrameCodec(t *testing.T) { sid: 0, method: "method", header: map[string]string{"key": "value"}, - }, nil, headerFrameType, []byte("hello world")) + }, headerFrameType, []byte("hello world")) for i := 0; i < 10; i++ { wframe.sid = int32(i) @@ -53,25 +53,25 @@ func TestFrameCodec(t *testing.T) { } func TestFrameWithoutPayloadCodec(t *testing.T) { - rmsg := new(TestRequest) + rmsg := new(testRequest) rmsg.A = 1 payload, err := EncodePayload(context.Background(), rmsg) test.Assert(t, err == nil, err) - wmsg := new(TestRequest) + wmsg := new(testRequest) err = DecodePayload(context.Background(), payload, wmsg) test.Assert(t, err == nil, err) test.DeepEqual(t, wmsg, rmsg) } func TestPayloadCodec(t *testing.T) { - rmsg := new(TestRequest) + rmsg := new(testRequest) rmsg.A = 1 rmsg.B = "hello world" payload, err := EncodePayload(context.Background(), rmsg) test.Assert(t, err == nil, err) - wmsg := new(TestRequest) + wmsg := new(testRequest) err = DecodePayload(context.Background(), payload, wmsg) test.Assert(t, err == nil, err) test.DeepEqual(t, wmsg, rmsg) diff --git a/pkg/streamx/provider/ttstream/meta_frame_handler.go b/pkg/streamx/provider/ttstream/meta_frame_handler.go deleted file mode 100644 index e5dd6a54dc..0000000000 --- a/pkg/streamx/provider/ttstream/meta_frame_handler.go +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ttstream - -import ( - "sync" - - "github.com/cloudwego/kitex/pkg/streamx" -) - -type StreamMeta interface { - Meta() map[string]string - GetMeta(k string) (string, bool) - SetMeta(k, v string, kvs ...string) -} - -type MetaFrameHandler interface { - OnMetaFrame(smeta StreamMeta, intHeader IntHeader, header streamx.Header, payload []byte) error -} - -var _ StreamMeta = (*streamMeta)(nil) - -func newStreamMeta() StreamMeta { - return &streamMeta{} -} - -type streamMeta struct { - sync sync.RWMutex - data map[string]string -} - -func (s *streamMeta) Meta() map[string]string { - s.sync.RLock() - m := make(map[string]string, len(s.data)) - for k, v := range s.data { - m[k] = v - } - s.sync.RUnlock() - return m -} - -func (s *streamMeta) GetMeta(k string) (string, bool) { - s.sync.RLock() - v, ok := s.data[k] - s.sync.RUnlock() - return v, ok -} - -func (s *streamMeta) SetMeta(k, v string, kvs ...string) { - s.sync.RLock() - s.data[k] = v - for i := 0; i < len(kvs); i += 2 { - s.data[kvs[i]] = kvs[i+1] - } - s.sync.RUnlock() -} diff --git a/pkg/streamx/provider/ttstream/metadata.go b/pkg/streamx/provider/ttstream/metadata.go index fef0b03789..0732a6f7c3 100644 --- a/pkg/streamx/provider/ttstream/metadata.go +++ b/pkg/streamx/provider/ttstream/metadata.go @@ -25,10 +25,14 @@ import ( var ( ErrInvalidStreamKind = errors.New("invalid stream kind") ErrClosedStream = errors.New("stream is closed") + ErrCanceledStream = errors.New("stream is canceled") ) // only for meta frame handler -type IntHeader map[uint16]string +type ( + IntHeader map[uint16]string + StrHeader = streamx.Header +) // ClientStreamMeta cannot send header directly, should send from ctx type ClientStreamMeta interface { diff --git a/pkg/streamx/provider/ttstream/mock_test.go b/pkg/streamx/provider/ttstream/mock_test.go index a2d99b58dc..c0996650d9 100644 --- a/pkg/streamx/provider/ttstream/mock_test.go +++ b/pkg/streamx/provider/ttstream/mock_test.go @@ -24,12 +24,12 @@ import ( kutils "github.com/cloudwego/kitex/pkg/utils" ) -type TestRequest struct { +type testRequest struct { A int32 `thrift:"A,1" frugal:"1,default,i32" json:"A"` B string `thrift:"B,2" frugal:"2,default,string" json:"B"` } -func (p *TestRequest) FastRead(buf []byte) (int, error) { +func (p *testRequest) FastRead(buf []byte) (int, error) { err := json.Unmarshal(buf, p) if err != nil { return 0, err @@ -37,19 +37,19 @@ func (p *TestRequest) FastRead(buf []byte) (int, error) { return len(buf), nil } -func (p *TestRequest) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *testRequest) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { data, _ := json.Marshal(p) copy(buf, data) return len(data) } -func (p *TestRequest) BLength() int { +func (p *testRequest) BLength() int { data, _ := json.Marshal(p) return len(data) } -func (p *TestRequest) DeepCopy(s interface{}) error { - src, ok := s.(*TestRequest) +func (p *testRequest) DeepCopy(s interface{}) error { + src, ok := s.(*testRequest) if !ok { return fmt.Errorf("%T's type not matched %T", s, p) } @@ -63,4 +63,4 @@ func (p *TestRequest) DeepCopy(s interface{}) error { return nil } -type TestResponse = TestRequest +type testResponse = testRequest diff --git a/pkg/streamx/provider/ttstream/server_provider.go b/pkg/streamx/provider/ttstream/server_provider.go index 9fb2a55f90..0e032c3b06 100644 --- a/pkg/streamx/provider/ttstream/server_provider.go +++ b/pkg/streamx/provider/ttstream/server_provider.go @@ -18,6 +18,7 @@ package ttstream import ( "context" + "fmt" "net" "strconv" @@ -39,6 +40,9 @@ type ( serverStreamCancelCtxKey struct{} ) +var _ streamx.ServerProvider = (*serverProvider)(nil) + +// NewServerProvider return a server provider func NewServerProvider(sinfo *serviceinfo.ServiceInfo, opts ...ServerProviderOption) (streamx.ServerProvider, error) { sp := new(serverProvider) sp.sinfo = sinfo @@ -48,24 +52,29 @@ func NewServerProvider(sinfo *serviceinfo.ServiceInfo, opts ...ServerProviderOpt return sp, nil } -var _ streamx.ServerProvider = (*serverProvider)(nil) - type serverProvider struct { - sinfo *serviceinfo.ServiceInfo - payloadLimit int + sinfo *serviceinfo.ServiceInfo + metaHandler MetaFrameHandler + headerHandler HeaderFrameReadHandler } +// Available sniff the conn if provider can process func (s serverProvider) Available(ctx context.Context, conn net.Conn) bool { - data, err := conn.(netpoll.Connection).Reader().Peek(8) + nconn, ok := conn.(netpoll.Connection) + if !ok { + return false + } + data, err := nconn.Reader().Peek(8) if err != nil { return false } return ttheader.IsStreaming(data) } +// OnActive will be called when a connection accepted func (s serverProvider) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { nconn := conn.(netpoll.Connection) - trans := newTransport(serverTransport, s.sinfo, nconn) + trans := newTransport(serverTransport, s.sinfo, nconn, nil) _ = nconn.(onDisConnectSetter).SetOnDisconnect(func(ctx context.Context, connection netpoll.Connection) { // server only close transport when peer connection closed _ = trans.Close(nil) @@ -85,15 +94,28 @@ func (s serverProvider) OnInactive(ctx context.Context, conn net.Conn) (context. func (s serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.Context, streamx.ServerStream, error) { trans, _ := ctx.Value(serverTransCtxKey{}).(*transport) if trans == nil { - return nil, nil, nil + return nil, nil, fmt.Errorf("server transport is nil") } - st, err := trans.readStream(ctx) + + // ReadStream will block until a stream coming or conn return error + st, err := trans.ReadStream(ctx) if err != nil { return nil, nil, err } + st.setMetaFrameHandler(s.metaHandler) + + // headerHandler return a new stream level ctx + if s.headerHandler != nil { + ctx, err = s.headerHandler.OnReadStream(ctx, st.meta, st.header) + if err != nil { + return nil, nil, err + } + } + // register metainfo into ctx ctx = metainfo.SetMetaInfoFromMap(ctx, st.header) ss := newServerStream(st) + // cancel ctx when OnStreamFinish ctx, cancelFunc := ktx.WithCancel(ctx) ctx = context.WithValue(ctx, serverStreamCancelCtxKey{}, cancelFunc) return ctx, ss, nil @@ -101,7 +123,7 @@ func (s serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.Co func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStream, err error) (context.Context, error) { sst := ss.(*serverStream) - var exception tException + var exception error if err != nil { switch terr := err.(type) { case kerrors.BizStatusErrorIface: @@ -123,13 +145,17 @@ func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStr if err != nil { return nil, err } - case tException: + exception = nil + case *thrift.ApplicationException: exception = terr + case tException: + exception = thrift.NewApplicationException(terr.TypeId(), terr.Error()) default: exception = thrift.NewApplicationException(remote.InternalError, terr.Error()) } } - if err = sst.close(exception); err != nil { + // server stream CloseSend will send the trailer with payload + if err = sst.CloseSend(exception); err != nil { return nil, err } diff --git a/pkg/streamx/provider/ttstream/server_provider_option.go b/pkg/streamx/provider/ttstream/server_provider_option.go new file mode 100644 index 0000000000..5c2fe3b8cf --- /dev/null +++ b/pkg/streamx/provider/ttstream/server_provider_option.go @@ -0,0 +1,34 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +// ServerProviderOption define server provider options +type ServerProviderOption func(pc *serverProvider) + +// WithServerMetaFrameHandler register TTHeader Streaming meta frame handler +func WithServerMetaFrameHandler(handler MetaFrameHandler) ServerProviderOption { + return func(sp *serverProvider) { + sp.metaHandler = handler + } +} + +// WithServerHeaderFrameHandler register TTHeader Streaming header frame handler +func WithServerHeaderFrameHandler(handler HeaderFrameReadHandler) ServerProviderOption { + return func(sp *serverProvider) { + sp.headerHandler = handler + } +} diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index b650e35bc6..4b4f21ac52 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -23,9 +23,13 @@ import ( "sync/atomic" "time" + "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx" terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" @@ -33,29 +37,36 @@ import ( ) var ( - _ streamx.ClientStream = (*clientStream)(nil) - _ streamx.ServerStream = (*serverStream)(nil) - _ streamx.ClientStreamMetadata = (*clientStream)(nil) - _ streamx.ServerStreamMetadata = (*serverStream)(nil) - _ StreamMeta = (*stream)(nil) + _ streamx.ClientStream = (*clientStream)(nil) + _ streamx.ServerStream = (*serverStream)(nil) + _ StreamMeta = (*stream)(nil) ) -func newStream(trans *transport, mode streamx.StreamingMode, smeta streamFrame) *stream { +func newStream(ctx context.Context, trans *transport, mode streamx.StreamingMode, smeta streamFrame) *stream { s := new(stream) s.streamFrame = smeta + s.StreamMeta = newStreamMeta() + s.reader = newStreamReader() s.trans = trans s.mode = mode s.wheader = make(streamx.Header) s.wtrailer = make(streamx.Trailer) s.headerSig = make(chan int32, 1) s.trailerSig = make(chan int32, 1) - s.StreamMeta = newStreamMeta() + + // register close callback + copts := streamxcallopt.GetCallOptionsFromCtx(ctx) + if copts != nil && len(copts.StreamCloseCallback) > 0 { + s.closeCallback = append(s.closeCallback, copts.StreamCloseCallback...) + } return s } +// streamFrame define a basic stream frame type streamFrame struct { sid int32 method string + meta IntHeader header streamx.Header // key:value, key is full name trailer streamx.Trailer } @@ -64,24 +75,29 @@ const ( streamSigNone int32 = 0 streamSigActive int32 = 1 streamSigInactive int32 = -1 + streamSigCancel int32 = -2 ) +// stream is used to process frames and expose user APIs type stream struct { streamFrame - trans *transport - mode streamx.StreamingMode - wheader streamx.Header // wheader == nil means it already be sent - wtrailer streamx.Trailer // wtrailer == nil means it already be sent - selfEOF int32 - peerEOF int32 + StreamMeta + reader *streamReader + trans *transport + mode streamx.StreamingMode + wheader streamx.Header // wheader == nil means it already be sent + wtrailer streamx.Trailer // wtrailer == nil means it already be sent + headerSig chan int32 trailerSig chan int32 - sio *streamIO - closedFlag int32 // 1 means stream is closed in exception scenario - StreamMeta - metaHandler MetaFrameHandler - recvTimeout time.Duration + selfEOF int32 + peerEOF int32 + eofFlag int32 + + recvTimeout time.Duration + metaFrameHandler MetaFrameHandler + closeCallback []streamxcallopt.StreamCloseCallback } func (s *stream) Mode() streamx.StreamingMode { @@ -99,10 +115,77 @@ func (s *stream) Method() string { return s.method } -// close stream in exception scenario -func (s *stream) close(exception error) { - if !atomic.CompareAndSwapInt32(&s.closedFlag, 0, 1) { - return +func (s *stream) SendMsg(ctx context.Context, msg any) (err error) { + if atomic.LoadInt32(&s.selfEOF) != 0 { + return terrors.ErrIllegalOperation.WithCause(errors.New("stream is close send")) + } + // encode payload + payload, err := EncodePayload(ctx, msg) + if err != nil { + return err + } + // tracing + ri := rpcinfo.GetRPCInfo(ctx) + if ri != nil && ri.Stats() != nil { + if rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()); rpcStats != nil { + rpcStats.IncrSendSize(uint64(len(payload))) + } + } + // send data frame + return s.writeFrame(dataFrameType, nil, nil, payload) +} + +func (s *stream) RecvMsg(ctx context.Context, data any) error { + if s.recvTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, s.recvTimeout) + defer cancel() + } + payload, err := s.reader.output(ctx) + if err != nil { + return err + } + err = DecodePayload(context.Background(), payload, data) + // payload will not be access after decode + mcache.Free(payload) + + // tracing + ri := rpcinfo.GetRPCInfo(ctx) + if ri != nil && ri.Stats() != nil { + if rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()); rpcStats != nil { + rpcStats.IncrRecvSize(uint64(len(payload))) + } + } + return err +} + +// closeSend should be called when following cases happen: +// client: +// - user call CloseSend +// - client recv a trailer +// - transport layer exception +// server: +// - server handler return +// - transport layer exception +func (s *stream) closeSend(exception error) error { + if !atomic.CompareAndSwapInt32(&s.selfEOF, 0, 1) { + return nil + } + err := s.sendTrailer(exception) + s.tryRunCloseCallback() + return err +} + +// closeRecv should be called when following cases happen: +// client: +// - transport layer exception +// - client stream is GCed +// server: +// - transport layer exception +// - server handler return +func (s *stream) closeRecv(exception error) error { + if !atomic.CompareAndSwapInt32(&s.peerEOF, 0, 1) { + return nil } select { case s.headerSig <- streamSigInactive: @@ -112,45 +195,55 @@ func (s *stream) close(exception error) { case s.trailerSig <- streamSigInactive: default: } - s.sio.close(exception) - s.trans.streamDelete(s.sid) -} - -func (s *stream) isClosed() bool { - return atomic.LoadInt32(&s.closedFlag) == 1 + s.reader.close(exception) + s.tryRunCloseCallback() + return nil } -func (s *stream) isSendFinished() bool { - return atomic.LoadInt32(&s.selfEOF) == 1 +func (s *stream) cancel() error { + if !atomic.CompareAndSwapInt32(&s.peerEOF, 0, 1) { + return nil + } + select { + case s.headerSig <- streamSigCancel: + default: + } + select { + case s.trailerSig <- streamSigCancel: + default: + } + s.reader.cancel() + return nil } -func (s *stream) cancel() { - s.sio.cancel() +func (s *stream) setRecvTimeout(timeout time.Duration) { + if timeout <= 0 { + return + } + s.recvTimeout = timeout } -func (s *stream) setMetaFrameHandler(h MetaFrameHandler) { - s.metaHandler = h +func (s *stream) setMetaFrameHandler(metaHandler MetaFrameHandler) { + s.metaFrameHandler = metaHandler } -func (s *stream) readMetaFrame(intHeader IntHeader, header streamx.Header, payload []byte) (err error) { - if s.metaHandler == nil { - return nil +func (s *stream) tryRunCloseCallback() { + if atomic.AddInt32(&s.eofFlag, 1) != 2 { + return } - return s.metaHandler.OnMetaFrame(s.StreamMeta, intHeader, header, payload) + if len(s.closeCallback) > 0 { + for _, cb := range s.closeCallback { + cb() + } + } + s.trans.deleteStream(s.sid) + s.trans.recycle() } -func (s *stream) readHeader(hd streamx.Header) (err error) { - if s.header != nil { - return terrors.ErrUnexpectedHeader.WithCause(fmt.Errorf("stream[%d] already set header", s.sid)) - } - s.header = hd - select { - case s.headerSig <- streamSigActive: - default: - return terrors.ErrUnexpectedHeader.WithCause(fmt.Errorf("stream[%d] already set header", s.sid)) - } - klog.Debugf("stream[%s] read header: %v", s.method, hd) - return nil +func (s *stream) writeFrame(ftype int32, header streamx.Header, trailer streamx.Trailer, payload []byte) (err error) { + return s.trans.writeFrame( + streamFrame{sid: s.sid, method: s.method, header: header, trailer: trailer}, ftype, payload, + ) } // writeHeader copy kvs into s.wheader @@ -164,55 +257,18 @@ func (s *stream) writeHeader(hd streamx.Header) error { return nil } +// sendHeader send header to peer func (s *stream) sendHeader() (err error) { wheader := s.wheader s.wheader = nil if wheader == nil { return fmt.Errorf("stream header already sent") } - err = s.trans.streamSendHeader(s, wheader) + err = s.writeFrame(headerFrameType, wheader, nil, nil) return err } -// readTrailer by client: unblock recv function and return EOF if no unread frame -// readTrailer by server: unblock recv function and return EOF if no unread frame -func (s *stream) readTrailerFrame(fr *Frame) (err error) { - if !atomic.CompareAndSwapInt32(&s.peerEOF, 0, 1) { - return terrors.ErrUnexpectedTrailer.WithCause(fmt.Errorf("content: %v", fr)) - } - - var exception error - // when server-side returns non-biz error, it will be wrapped as ApplicationException stored in trailer frame payload - if len(fr.payload) > 0 { - // exception is type of (*thrift.ApplicationException) - _, _, err = thrift.UnmarshalFastMsg(fr.payload, nil) - exception = terrors.ErrApplicationException.WithCause(err) - } else { - // when server-side returns biz error, payload is empty and biz error information is stored in trailer frame header - bizErr, err := transmeta.ParseBizStatusErr(fr.trailer) - if err != nil { - exception = terrors.ErrIllegalBizErr.WithCause(err) - } else if bizErr != nil { - // bizErr is independent of rpc exception handling - exception = bizErr - } - } - s.trailer = fr.trailer - select { - case s.trailerSig <- streamSigActive: - default: - return terrors.ErrUnexpectedTrailer.WithCause(errors.New("already set trailer")) - } - select { - case s.headerSig <- streamSigNone: - // if trailer arrived, we should return unblock stream.Header() - default: - } - - klog.Debugf("stream[%d] recv trailer: %v, exception: %v", s.sid, s.trailer, exception) - return s.trans.streamCloseRecv(s, exception) -} - +// writeTrailer write trailer to peer func (s *stream) writeTrailer(tl streamx.Trailer) (err error) { if s.wtrailer == nil { return fmt.Errorf("stream trailer already sent") @@ -223,88 +279,94 @@ func (s *stream) writeTrailer(tl streamx.Trailer) (err error) { return nil } -func (s *stream) sendTrailer(ctx context.Context, ex tException) (err error) { - if !atomic.CompareAndSwapInt32(&s.selfEOF, 0, 1) { - return nil - } +// writeTrailer send trailer to peer +// if exception is not nil, trailer frame should carry a payload +func (s *stream) sendTrailer(exception error) (err error) { wtrailer := s.wtrailer s.wtrailer = nil if wtrailer == nil { return fmt.Errorf("stream trailer already sent") } - klog.Debugf("transport[%d]-stream[%d] send trailer", s.trans.kind, s.sid) - return s.trans.streamCloseSend(s, wtrailer, ex) -} + klog.Debugf("transport[%d]-stream[%d] send trailer: err=%v", s.trans.kind, s.sid, exception) -func (s *stream) setRecvTimeout(timeout time.Duration) { - if timeout <= 0 { - return + var payload []byte + if exception != nil { + payload, err = EncodeException(context.Background(), s.method, s.sid, exception) + if err != nil { + return err + } } - s.recvTimeout = timeout -} - -func (s *stream) SendMsg(ctx context.Context, res any) (err error) { - err = s.trans.streamSend(ctx, s, res) + err = s.writeFrame(trailerFrameType, nil, wtrailer, payload) return err } -func (s *stream) RecvMsg(ctx context.Context, req any) error { - if s.recvTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, s.recvTimeout) - defer cancel() - } - return s.trans.streamRecv(ctx, s, req) -} - -func newClientStream(s *stream) *clientStream { - cs := &clientStream{stream: s} - return cs -} - -type clientStream struct { - *stream -} - -func (s *clientStream) RecvMsg(ctx context.Context, req any) error { - return s.stream.RecvMsg(ctx, req) -} - -func (s *clientStream) CloseSend(ctx context.Context) error { - return s.sendTrailer(ctx, nil) -} +// === Frame OnRead callback -func newServerStream(s *stream) streamx.ServerStream { - ss := &serverStream{stream: s} - return ss +func (s *stream) onReadMetaFrame(fr *Frame) (err error) { + if s.metaFrameHandler == nil { + return nil + } + return s.metaFrameHandler.OnMetaFrame(s.StreamMeta, fr.meta, fr.header, fr.payload) } -type serverStream struct { - *stream +func (s *stream) onReadHeaderFrame(fr *Frame) (err error) { + if s.header != nil { + return terrors.ErrUnexpectedHeader.WithCause(fmt.Errorf("stream[%d] already set header", s.sid)) + } + s.header = fr.header + select { + case s.headerSig <- streamSigActive: + default: + return terrors.ErrUnexpectedHeader.WithCause(fmt.Errorf("stream[%d] already set header", s.sid)) + } + klog.Debugf("stream[%s] read header: %v", s.method, fr.header) + return nil } -func (s *serverStream) RecvMsg(ctx context.Context, req any) error { - return s.stream.RecvMsg(ctx, req) +func (s *stream) onReadDataFrame(fr *Frame) (err error) { + s.reader.input(context.Background(), fr.payload) + return nil } -// SendMsg should send left header first -func (s *serverStream) SendMsg(ctx context.Context, res any) error { - if len(s.wheader) > 0 { - if err := s.sendHeader(); err != nil { - return err +// onReadTrailerFrame by client: unblock recv function and return EOF if no unread frame +// onReadTrailerFrame by server: unblock recv function and return EOF if no unread frame +func (s *stream) onReadTrailerFrame(fr *Frame) (err error) { + var exception error + // when server-side returns non-biz error, it will be wrapped as ApplicationException stored in trailer frame payload + if len(fr.payload) > 0 { + // exception is type of (*thrift.ApplicationException) + _, _, err = thrift.UnmarshalFastMsg(fr.payload, nil) + exception = terrors.ErrApplicationException.WithCause(err) + } else if len(fr.trailer) > 0 { + // when server-side returns biz error, payload is empty and biz error information is stored in trailer frame header + bizErr, err := transmeta.ParseBizStatusErr(fr.trailer) + if err != nil { + exception = terrors.ErrIllegalBizErr.WithCause(err) + } else if bizErr != nil { + // bizErr is independent of rpc exception handling + exception = bizErr } } - return s.stream.SendMsg(ctx, res) -} + s.trailer = fr.trailer + select { + case s.trailerSig <- streamSigActive: + default: + } + select { + case s.headerSig <- streamSigNone: + // if trailer arrived, we should return unblock stream.Header() + default: + } -// close will be called after server handler returned -// after close stream cannot be access again -func (s *serverStream) close(ex tException) error { - // write loop should help to delete stream - err := s.sendTrailer(context.Background(), ex) - if err != nil { - return err + klog.Debugf("stream[%d] recv trailer: %v, exception: %v", s.sid, s.trailer, exception) + switch s.trans.kind { + case clientTransport: + // if client recv trailer, server handler must be return, + // so we don't need to send data anymore + err = s.closeRecv(exception) + case serverTransport: + // if server recv trailer, we only need to close recv but still can send data + err = s.closeRecv(exception) } - s.stream.close(ex) - return nil + return err } diff --git a/pkg/streamx/provider/ttstream/stream_header_trailer.go b/pkg/streamx/provider/ttstream/stream_client.go similarity index 63% rename from pkg/streamx/provider/ttstream/stream_header_trailer.go rename to pkg/streamx/provider/ttstream/stream_client.go index 611c12a0ca..28783a2a28 100644 --- a/pkg/streamx/provider/ttstream/stream_header_trailer.go +++ b/pkg/streamx/provider/ttstream/stream_client.go @@ -17,25 +17,34 @@ package ttstream import ( + "context" "errors" "github.com/cloudwego/kitex/pkg/streamx" ) -var ( - _ ClientStreamMeta = (*clientStream)(nil) - _ ServerStreamMeta = (*serverStream)(nil) -) +var _ ClientStreamMeta = (*clientStream)(nil) + +func newClientStream(s *stream) *clientStream { + cs := &clientStream{stream: s} + return cs +} + +type clientStream struct { + *stream +} func (s *clientStream) Header() (streamx.Header, error) { sig := <-s.headerSig switch sig { - case streamSigActive: - return s.header, nil case streamSigNone: return make(streamx.Header), nil + case streamSigActive: + return s.header, nil case streamSigInactive: return nil, ErrClosedStream + case streamSigCancel: + return nil, ErrCanceledStream } return nil, errors.New("invalid stream signal") } @@ -43,27 +52,29 @@ func (s *clientStream) Header() (streamx.Header, error) { func (s *clientStream) Trailer() (streamx.Trailer, error) { sig := <-s.trailerSig switch sig { - case streamSigActive: - return s.trailer, nil case streamSigNone: return make(streamx.Trailer), nil + case streamSigActive: + return s.trailer, nil case streamSigInactive: return nil, ErrClosedStream + case streamSigCancel: + return nil, ErrCanceledStream } return nil, errors.New("invalid stream signal") } -func (s *serverStream) SetHeader(hd streamx.Header) error { - return s.writeHeader(hd) +func (s *clientStream) RecvMsg(ctx context.Context, req any) error { + return s.stream.RecvMsg(ctx, req) } -func (s *serverStream) SendHeader(hd streamx.Header) error { - if err := s.writeHeader(hd); err != nil { - return err - } - return s.stream.sendHeader() +// CloseSend by clientStream only send trailer frame and will not close the stream +func (s *clientStream) CloseSend(ctx context.Context) error { + return s.closeSend(nil) } -func (s *serverStream) SetTrailer(tl streamx.Trailer) error { - return s.writeTrailer(tl) +func clientStreamFinalizer(s *clientStream) { + // it's safe to call CloseSend twice + // we do CloseSend here to ensure stream can be closed normally + _ = s.CloseSend(context.Background()) } diff --git a/pkg/streamx/provider/ttstream/stream_io.go b/pkg/streamx/provider/ttstream/stream_io.go deleted file mode 100644 index 78b84990d8..0000000000 --- a/pkg/streamx/provider/ttstream/stream_io.go +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ttstream - -import ( - "context" - "errors" - "io" - "sync/atomic" - - "github.com/cloudwego/kitex/pkg/klog" - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" -) - -type streamIOMsg struct { - payload []byte - exception error -} - -type streamIO struct { - ctx context.Context - pipe *container.Pipe[streamIOMsg] - cache [1]streamIOMsg - exception error // once has exception, the stream should not work normally again - eofFlag int32 - callbackFlag int32 - closeCallback func(ctx context.Context) -} - -func newStreamIO(ctx context.Context) *streamIO { - sio := new(streamIO) - sio.ctx = ctx - sio.pipe = container.NewPipe[streamIOMsg]() - return sio -} - -func (s *streamIO) input(ctx context.Context, msg streamIOMsg) { - err := s.pipe.Write(ctx, msg) - if err != nil { - klog.Errorf("pipe write failed: %v", err) - } -} - -func (s *streamIO) output(ctx context.Context) (msg streamIOMsg, err error) { - if s.exception != nil { - return msg, s.exception - } - - n, err := s.pipe.Read(ctx, s.cache[:]) - if err != nil { - if errors.Is(err, container.ErrPipeEOF) { - err = io.EOF - } - s.exception = err - return msg, s.exception - } - if n == 0 { - s.exception = io.EOF - return msg, s.exception - } - msg = s.cache[0] - if msg.exception != nil { - s.exception = msg.exception - return msg, s.exception - } - return msg, nil -} - -func (s *streamIO) runCloseCallback() { - if s.closeCallback != nil && atomic.CompareAndSwapInt32(&s.callbackFlag, 0, 1) { - s.closeCallback(s.ctx) - } -} - -func (s *streamIO) closeRecv() { - s.pipe.Close() - if s.closeCallback != nil && atomic.AddInt32(&s.eofFlag, 1) == 2 { - s.runCloseCallback() - } -} - -func (s *streamIO) closeSend() { - if s.closeCallback != nil && atomic.AddInt32(&s.eofFlag, 1) == 2 { - s.runCloseCallback() - } -} - -func (s *streamIO) cancel() { - s.pipe.Cancel() - s.runCloseCallback() -} - -func (s *streamIO) close(exception error) { - if exception != nil { - s.input(context.Background(), streamIOMsg{exception: exception}) - } - s.pipe.Close() - if flag := atomic.AddInt32(&s.eofFlag, 2); (flag == 2 || flag == 3) && s.closeCallback != nil { - s.runCloseCallback() - } -} diff --git a/pkg/streamx/provider/ttstream/stream_reader.go b/pkg/streamx/provider/ttstream/stream_reader.go new file mode 100644 index 0000000000..fe23ba7086 --- /dev/null +++ b/pkg/streamx/provider/ttstream/stream_reader.go @@ -0,0 +1,87 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "context" + "errors" + "io" + + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" +) + +// streamReader is an abstraction layer for stream level IO operations +type streamReader struct { + pipe *container.Pipe[streamMsg] + cache [1]streamMsg + exception error // once has exception, the stream should not work normally again +} + +type streamMsg struct { + payload []byte + exception error +} + +func newStreamReader() *streamReader { + sio := new(streamReader) + sio.pipe = container.NewPipe[streamMsg]() + return sio +} + +func (s *streamReader) input(ctx context.Context, payload []byte) { + err := s.pipe.Write(ctx, streamMsg{payload: payload}) + if err != nil { + klog.Errorf("pipe write failed: %v", err) + } +} + +func (s *streamReader) output(ctx context.Context) (payload []byte, err error) { + if s.exception != nil { + return nil, s.exception + } + + n, err := s.pipe.Read(ctx, s.cache[:]) + if err != nil { + if errors.Is(err, container.ErrPipeEOF) { + err = io.EOF + } + s.exception = err + return nil, s.exception + } + if n == 0 { + s.exception = io.EOF + return nil, s.exception + } + msg := s.cache[0] + if msg.exception != nil { + s.exception = msg.exception + return nil, s.exception + } + return msg.payload, nil +} + +func (s *streamReader) cancel() { + s.pipe.Cancel() +} + +func (s *streamReader) close(exception error) { + if exception != nil { + _ = s.pipe.Write(context.Background(), streamMsg{exception: exception}) + } + s.pipe.Close() +} diff --git a/pkg/streamx/provider/ttstream/stream_reader_test.go b/pkg/streamx/provider/ttstream/stream_reader_test.go new file mode 100644 index 0000000000..b6d831ac85 --- /dev/null +++ b/pkg/streamx/provider/ttstream/stream_reader_test.go @@ -0,0 +1,75 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestStreamReader(t *testing.T) { + ctx := context.Background() + msg := []byte("hello world") + round := 10000 + + // basic IOs + sio := newStreamReader() + var done int32 + go func() { + for i := 0; i < round; i++ { + sio.input(ctx, msg) + } + atomic.StoreInt32(&done, 1) + }() + for i := 0; i < round; i++ { + payload, err := sio.output(ctx) + test.Assert(t, err == nil, err) + test.DeepEqual(t, msg, payload) + } + test.Assert(t, atomic.LoadInt32(&done) == int32(1)) + + // exception IOs + sio.input(ctx, msg) + targetErr := errors.New("test") + sio.close(targetErr) + payload, err := sio.output(ctx) + test.Assert(t, err == nil, err) + test.DeepEqual(t, msg, payload) + payload, err = sio.output(ctx) + test.Assert(t, payload == nil, payload) + test.Assert(t, errors.Is(err, targetErr), err) + + // ctx canceled IOs + ctx, cancel := context.WithCancel(ctx) + sio = newStreamReader() + sio.input(ctx, msg) + go func() { + time.Sleep(time.Millisecond * 10) + cancel() + }() + payload, err = sio.output(ctx) + test.Assert(t, err == nil, err) + test.DeepEqual(t, msg, payload) + payload, err = sio.output(ctx) + test.Assert(t, payload == nil, payload) + test.Assert(t, errors.Is(err, context.Canceled), err) +} diff --git a/pkg/streamx/provider/ttstream/stream_server.go b/pkg/streamx/provider/ttstream/stream_server.go new file mode 100644 index 0000000000..825a2dd0e8 --- /dev/null +++ b/pkg/streamx/provider/ttstream/stream_server.go @@ -0,0 +1,73 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "context" + + "github.com/cloudwego/kitex/pkg/streamx" +) + +var _ ServerStreamMeta = (*serverStream)(nil) + +func newServerStream(s *stream) *serverStream { + ss := &serverStream{stream: s} + return ss +} + +type serverStream struct { + *stream +} + +func (s *serverStream) SetHeader(hd streamx.Header) error { + return s.writeHeader(hd) +} + +func (s *serverStream) SendHeader(hd streamx.Header) error { + if err := s.writeHeader(hd); err != nil { + return err + } + return s.stream.sendHeader() +} + +func (s *serverStream) SetTrailer(tl streamx.Trailer) error { + return s.writeTrailer(tl) +} + +func (s *serverStream) RecvMsg(ctx context.Context, req any) error { + return s.stream.RecvMsg(ctx, req) +} + +// SendMsg should send left header first +func (s *serverStream) SendMsg(ctx context.Context, res any) error { + if s.wheader != nil { + if err := s.sendHeader(); err != nil { + return err + } + } + return s.stream.SendMsg(ctx, res) +} + +// CloseSend by serverStream will be called after server handler returned +// after CloseSend stream cannot be access again +func (s *serverStream) CloseSend(exception error) error { + err := s.closeSend(exception) + if err != nil { + return err + } + return err +} diff --git a/pkg/streamx/provider/ttstream/stream_test.go b/pkg/streamx/provider/ttstream/stream_test.go new file mode 100644 index 0000000000..bee94df441 --- /dev/null +++ b/pkg/streamx/provider/ttstream/stream_test.go @@ -0,0 +1,56 @@ +//go:build !windows + +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "context" + "testing" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestGenericStreaming(t *testing.T) { + cs, ss, err := newTestStreamPipe(testServiceInfo, "Bidi") + test.Assert(t, err == nil, err) + ctx := context.Background() + + // raw struct + req := new(testRequest) + req.A = 123 + req.B = "hello" + err = cs.SendMsg(ctx, req) + test.Assert(t, err == nil, err) + res := new(testResponse) + err = ss.RecvMsg(ctx, res) + test.Assert(t, err == nil, err) + test.Assert(t, res.A == req.A) + test.Assert(t, res.B == req.B) + + // map generic + // client side + // reqJSON := `{"A":123, "b":"hello"}` + // err = cs.SendMsg(ctx, reqJSON) + // test.Assert(t, err == nil, err) + // server side + // res = new(testResponse) + // err = ss.RecvMsg(ctx, res) + // test.Assert(t, err == nil, err) + // test.Assert(t, res.A == req.A) + // test.Assert(t, res.B == req.B) +} diff --git a/pkg/streamx/provider/ttstream/test_utils.go b/pkg/streamx/provider/ttstream/test_utils.go new file mode 100644 index 0000000000..e78ad68738 --- /dev/null +++ b/pkg/streamx/provider/ttstream/test_utils.go @@ -0,0 +1,55 @@ +//go:build !windows + +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "context" + + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" +) + +func newTestStreamPipe(sinfo *serviceinfo.ServiceInfo, method string) (*clientStream, *serverStream, error) { + cfd, sfd := netpoll.GetSysFdPairs() + cconn, err := netpoll.NewFDConnection(cfd) + if err != nil { + return nil, nil, err + } + sconn, err := netpoll.NewFDConnection(sfd) + if err != nil { + return nil, nil, err + } + + intHeader := make(IntHeader) + strHeader := make(streamx.Header) + ctrans := newTransport(clientTransport, sinfo, cconn, nil) + rawClientStream, err := ctrans.WriteStream(context.Background(), method, intHeader, strHeader) + if err != nil { + return nil, nil, err + } + strans := newTransport(serverTransport, sinfo, sconn, nil) + rawServerStream, err := strans.ReadStream(context.Background()) + if err != nil { + return nil, nil, err + } + + return newClientStream(rawClientStream), newServerStream(rawServerStream), nil +} diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index c5e08be6c2..ad4f84c0cd 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -21,19 +21,14 @@ import ( "errors" "fmt" "io" + "net" "sync" "sync/atomic" "time" - "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/gopkg/bufiox" - "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/netpoll" - "github.com/cloudwego/kitex/pkg/rpcinfo" - - "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" - "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" @@ -52,21 +47,22 @@ func isIgnoreError(err error) bool { return errors.Is(err, netpoll.ErrEOF) || errors.Is(err, io.EOF) || errors.Is(err, netpoll.ErrConnClosed) } +// transport is used to read/write frames and disturbed frames to different streams type transport struct { kind int32 sinfo *serviceinfo.ServiceInfo conn netpoll.Connection + pool transPool // transport should operate directly on stream streams sync.Map // key=streamID val=stream scache []*stream // size is streamCacheSize spipe *container.Pipe[*stream] // in-coming stream pipe fpipe *container.Pipe[*Frame] // out-coming frame pipe closedFlag int32 - streamingFlag int32 // flag == 0 means there is no active stream on transport closedTrigger chan struct{} } -func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Connection) *transport { +func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Connection, pool transPool) *transport { // stream max idle session is 10 minutes. // TODO: let it configurable _ = conn.SetReadTimeout(time.Minute * 10) @@ -74,6 +70,7 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne kind: kind, sinfo: sinfo, conn: conn, + pool: pool, streams: sync.Map{}, spipe: container.NewPipe[*stream](), scache: make([]*stream, 0, streamCacheSize), @@ -87,7 +84,9 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne err := t.loopRead() if err != nil { if !isIgnoreError(err) { - klog.Warnf("transport[%d] loop read err: %v", t.kind, err) + klog.Warnf("transport[%d-%s] loop read err: %v", t.kind, t.Addr(), err) + } else { + klog.Debugf("transport[%d-%s] loop read err: %v", t.kind, t.Addr(), err) } // if connection is closed by peer, loop read should return ErrConnClosed error, // so we should close transport here @@ -101,7 +100,9 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne err := t.loopWrite() if err != nil { if !isIgnoreError(err) { - klog.Warnf("transport[%d] loop write err: %v", t.kind, err) + klog.Warnf("transport[%d-%s] loop write err: %v", t.kind, t.Addr(), err) + } else { + klog.Debugf("transport[%d-%s] loop write err: %v", t.kind, t.Addr(), err) } _ = t.Close(err) } @@ -109,6 +110,16 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne return t } +func (t *transport) Addr() net.Addr { + switch t.kind { + case clientTransport: + return t.conn.LocalAddr() + case serverTransport: + return t.conn.RemoteAddr() + } + return nil +} + // Close will close transport and destroy all resource and goroutines // server close transport when connection is disconnected // client close transport when transPool discard the transport @@ -118,17 +129,13 @@ func (t *transport) Close(exception error) (err error) { if !atomic.CompareAndSwapInt32(&t.closedFlag, 0, 1) { return nil } - switch t.kind { - case clientTransport: - klog.Debugf("transport[%d-%s] is closing", t.kind, t.conn.LocalAddr()) - case serverTransport: - klog.Debugf("transport[%d-%s] is closing", t.kind, t.conn.RemoteAddr()) - } + klog.Debugf("transport[%d-%s] is closing", t.kind, t.Addr()) t.spipe.Close() t.fpipe.Close() t.streams.Range(func(key, value any) bool { s := value.(*stream) - s.close(exception) + _ = s.closeSend(exception) + _ = s.closeRecv(exception) return true }) return err @@ -143,12 +150,8 @@ func (t *transport) IsActive() bool { return atomic.LoadInt32(&t.closedFlag) == 0 && t.conn.IsActive() } -func (t *transport) storeStream(ctx context.Context, s *stream) { - s.sio = newStreamIO(ctx) - copts := streamxcallopt.GetCallOptionsFromCtx(ctx) - if copts != nil && copts.StreamCloseCallback != nil { - s.sio.closeCallback = copts.StreamCloseCallback - } +func (t *transport) storeStream(s *stream) { + klog.Debugf("transport[%d-%s] store stream: sid=%d", t.kind, t.Addr(), s.sid) t.streams.Store(s.sid, s) } @@ -161,63 +164,59 @@ func (t *transport) loadStream(sid int32) (s *stream, ok bool) { return s, true } +func (t *transport) deleteStream(sid int32) { + klog.Debugf("transport[%d-%s] delete stream: sid=%d", t.kind, t.Addr(), sid) + // remove stream from transport + t.streams.Delete(sid) +} + +func (t *transport) recycle() { + if t.pool != nil { + t.pool.Put(t) + } +} + func (t *transport) readFrame(reader bufiox.Reader) error { fr, err := DecodeFrame(context.Background(), reader) if err != nil { return err } defer recycleFrame(fr) - klog.Debugf("transport[%d] DecodeFrame: fr=%v", t.kind, fr) - - switch fr.typ { - case metaFrameType: - s, ok := t.loadStream(fr.sid) - if ok { - err = s.readMetaFrame(fr.meta, fr.header, fr.payload) - } else { - klog.Errorf("transport[%d] read a unknown stream meta: sid=%d", t.kind, fr.sid) - } - case headerFrameType: - switch t.kind { - case serverTransport: - // Header Frame: server recv a new stream - smode := t.sinfo.MethodInfo(fr.method).StreamingMode() - s := newStream(t, smode, fr.streamFrame) - t.storeStream(context.Background(), s) - err = t.spipe.Write(context.Background(), s) - case clientTransport: - // Header Frame: client recv header - s, ok := t.loadStream(fr.sid) - if ok { - if sErr := s.readHeader(fr.header); sErr != nil { - s.close(sErr) - } - } else { - klog.Errorf("transport[%d] read a unknown stream header: sid=%d header=%v", - t.kind, fr.sid, fr.header) - } - } - case dataFrameType: - // Data Frame: decode and distribute data - s, ok := t.loadStream(fr.sid) - if ok { - s.sio.input(context.Background(), streamIOMsg{payload: fr.payload}) + klog.Debugf("transport[%d] DecodeFrame: frame=%s", t.kind, fr) + + var s *stream + if fr.typ == headerFrameType && t.kind == serverTransport { + // server recv a header frame, we should create a new stream + smode := t.sinfo.MethodInfo(fr.method).StreamingMode() + s = newStream(context.Background(), t, smode, fr.streamFrame) + t.storeStream(s) + err = t.spipe.Write(context.Background(), s) + } else { + // load exist stream + var ok bool + s, ok = t.loadStream(fr.sid) + if !ok { + klog.Errorf( + "transport[%d] read a unknown stream: frame[%s]", + t.kind, fr.String(), + ) + // ignore unknown stream error + err = nil } else { - klog.Errorf("transport[%d] read a unknown stream data: sid=%d", t.kind, fr.sid) - } - case trailerFrameType: - // Trailer Frame: recv trailer, Close read direction - s, ok := t.loadStream(fr.sid) - if ok { - if sErr := s.readTrailerFrame(fr); sErr != nil { - s.close(sErr) + // process different frames + switch fr.typ { + case metaFrameType: + err = s.onReadMetaFrame(fr) + case headerFrameType: + // process header frame for client transport + err = s.onReadHeaderFrame(fr) + case dataFrameType: + // process data frame: decode and distribute data + err = s.onReadDataFrame(fr) + case trailerFrameType: + // process trailer frame: close the stream read direction + err = s.onReadTrailerFrame(fr) } - } else { - // client recv an unknown trailer is in exception, - // because the client stream may already be GCed, - // but the connection is still active so peer server can send a trailer - klog.Errorf("transport[%d] read a unknown stream trailer: sid=%d trailer=%v", - t.kind, fr.sid, fr.trailer) } } return err @@ -254,10 +253,13 @@ func (t *transport) loopWrite() error { } for i := 0; i < n; i++ { fr := fcache[i] - klog.Debugf("transport[%d] EncodeFrame: fr=%v IsActive=%v", t.kind, fr, t.conn.IsActive()) + klog.Debugf("transport[%d] EncodeFrame: fr=%s", t.kind, fr) if err = EncodeFrame(context.Background(), writer, fr); err != nil { return err } + if err = t.conn.Writer().Flush(); err != nil { + return err + } recycleFrame(fr) } if err = t.conn.Writer().Flush(); err != nil { @@ -267,113 +269,17 @@ func (t *transport) loopWrite() error { } // writeFrame is concurrent safe -func (t *transport) writeFrame(sframe streamFrame, meta IntHeader, ftype int32, payload []byte) (err error) { - frame := newFrame(sframe, meta, ftype, payload) +func (t *transport) writeFrame(sframe streamFrame, ftype int32, payload []byte) (err error) { + frame := newFrame(sframe, ftype, payload) return t.fpipe.Write(context.Background(), frame) } -func (t *transport) streamSend(ctx context.Context, s *stream, res any) (err error) { - if s.isClosed() { - return s.sio.exception - } - if s.isSendFinished() { - return io.EOF - } - if len(s.wheader) > 0 { - err = t.streamSendHeader(s, s.wheader) - if err != nil { - return err - } - } - payload, err := EncodePayload(ctx, res) - if err != nil { - return err - } - // tracing - ri := rpcinfo.GetRPCInfo(ctx) - if ri != nil && ri.Stats() != nil { - if rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()); rpcStats != nil { - rpcStats.IncrSendSize(uint64(len(payload))) - } - } - return t.writeFrame( - streamFrame{sid: s.sid, method: s.method}, - nil, dataFrameType, payload, - ) -} - -func (t *transport) streamSendHeader(s *stream, header streamx.Header) (err error) { - return t.writeFrame( - streamFrame{sid: s.sid, method: s.method, header: header}, - nil, headerFrameType, nil) -} - -func (t *transport) streamCloseSend(s *stream, trailer streamx.Trailer, exception tException) (err error) { - var payload []byte - if exception != nil { - payload, err = EncodeException(context.Background(), s.method, s.sid, exception) - if err != nil { - return err - } - } - err = t.writeFrame( - streamFrame{sid: s.sid, method: s.method, trailer: trailer}, - nil, trailerFrameType, payload, - ) - if err != nil { - return err - } - s.sio.closeSend() - return nil -} - -func (t *transport) streamRecv(ctx context.Context, s *stream, data any) (err error) { - msg, err := s.sio.output(ctx) - if err != nil { - return err - } - err = DecodePayload(context.Background(), msg.payload, data.(thrift.FastCodec)) - // payload will not be access after decode - mcache.Free(msg.payload) - - // tracing - ri := rpcinfo.GetRPCInfo(ctx) - if ri != nil && ri.Stats() != nil { - if rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()); rpcStats != nil { - rpcStats.IncrRecvSize(uint64(len(msg.payload))) - } - } - return err -} - -func (t *transport) streamCloseRecv(s *stream, exception error) error { - if exception != nil { - s.close(exception) - } else { - s.sio.closeRecv() - } - return nil -} - -func (t *transport) streamDelete(sid int32) { - // remove stream from transport - _, ok := t.streams.LoadAndDelete(sid) - if !ok { - return - } - atomic.AddInt32(&t.streamingFlag, -1) -} - -func (t *transport) IsStreaming() bool { - return atomic.LoadInt32(&t.streamingFlag) > 0 -} - var clientStreamID int32 -// newStream create new stream on current connection +// WriteStream create new stream on current connection // it's typically used by client side // newStream is concurrency safe -func (t *transport) newStream( +func (t *transport) WriteStream( ctx context.Context, method string, intHeader IntHeader, strHeader streamx.Header, ) (*stream, error) { if t.kind != clientTransport { @@ -382,23 +288,22 @@ func (t *transport) newStream( sid := atomic.AddInt32(&clientStreamID, 1) smode := t.sinfo.MethodInfo(method).StreamingMode() - // create stream + // new stream first + s := newStream(ctx, t, smode, streamFrame{sid: sid, method: method}) + t.storeStream(s) + // send create stream request for server err := t.writeFrame( - streamFrame{sid: sid, method: method, header: strHeader}, - intHeader, headerFrameType, nil, + streamFrame{sid: sid, method: method, header: strHeader, meta: intHeader}, headerFrameType, nil, ) if err != nil { return nil, err } - s := newStream(t, smode, streamFrame{sid: sid, method: method}) - t.storeStream(ctx, s) - atomic.AddInt32(&t.streamingFlag, 1) return s, nil } -// readStream wait for a new incoming stream on current connection +// ReadStream wait for a new incoming stream on current connection // it's typically used by server side -func (t *transport) readStream(ctx context.Context) (*stream, error) { +func (t *transport) ReadStream(ctx context.Context) (*stream, error) { if t.kind != serverTransport { return nil, fmt.Errorf("transport already be used as other kind") } @@ -406,7 +311,6 @@ READ: if len(t.scache) > 0 { s := t.scache[len(t.scache)-1] t.scache = t.scache[:len(t.scache)-1] - atomic.AddInt32(&t.streamingFlag, 1) return s, nil } n, err := t.spipe.Read(ctx, t.scache[0:streamCacheSize]) diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go new file mode 100644 index 0000000000..444206a00d --- /dev/null +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -0,0 +1,230 @@ +//go:build !windows + +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "context" + "errors" + "io" + "strings" + "sync" + "testing" + + "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/serviceinfo" + "github.com/cloudwego/kitex/pkg/streamx" + terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" + "github.com/cloudwego/kitex/server/streamxserver" +) + +var testServiceInfo = &serviceinfo.ServiceInfo{ + ServiceName: "kitex.service.streaming", + Methods: map[string]serviceinfo.MethodInfo{ + "Bidi": serviceinfo.NewMethodInfo( + func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { + return streamxserver.InvokeStream[testRequest, testResponse]( + ctx, serviceinfo.StreamingBidirectional, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + }, + nil, + nil, + false, + serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + ), + }, + Extra: map[string]interface{}{"streaming": true, "streamx": true}, +} + +func TestTransportBasic(t *testing.T) { + cfd, sfd := netpoll.GetSysFdPairs() + cconn, err := netpoll.NewFDConnection(cfd) + test.Assert(t, err == nil, err) + sconn, err := netpoll.NewFDConnection(sfd) + test.Assert(t, err == nil, err) + + intHeader := make(IntHeader) + intHeader[0] = "test" + strHeader := make(streamx.Header) + strHeader["key"] = "val" + ctrans := newTransport(clientTransport, testServiceInfo, cconn, nil) + rawClientStream, err := ctrans.WriteStream(context.Background(), "Bidi", intHeader, strHeader) + test.Assert(t, err == nil, err) + strans := newTransport(serverTransport, testServiceInfo, sconn, nil) + rawServerStream, err := strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + + var wg sync.WaitGroup + wg.Add(1) + // client + go func() { + defer wg.Done() + cs := newClientStream(rawClientStream) + req := new(testRequest) + req.B = "hello" + err := cs.SendMsg(context.Background(), req) + test.Assert(t, err == nil, err) + t.Logf("client stream send msg: %v", req) + + res := new(testResponse) + err = cs.RecvMsg(context.Background(), res) + t.Logf("client stream recv msg: %v", res) + test.Assert(t, err == nil, err) + test.DeepEqual(t, req.B, res.B) + + hd, err := cs.Header() + test.Assert(t, err == nil, err) + test.DeepEqual(t, hd["key"], strHeader["key"]) + t.Logf("client stream recv header: %v", hd) + + err = cs.CloseSend(context.Background()) + test.Assert(t, err == nil, err) + t.Logf("client stream close send") + }() + + // server + ss := newServerStream(rawServerStream) + err = ss.SendHeader(strHeader) + test.Assert(t, err == nil, err) + t.Logf("server stream send header: %v", strHeader) + + req := new(testRequest) + err = ss.RecvMsg(context.Background(), req) + test.Assert(t, err == nil, err) + t.Logf("server stream recv msg: %v", req) + res := new(testResponse) + res.B = req.B + err = ss.SendMsg(context.Background(), res) + test.Assert(t, err == nil, err) + t.Logf("server stream send msg: %v", req) + err = ss.RecvMsg(context.Background(), req) + test.Assert(t, err == io.EOF, err) + t.Logf("server stream recv msg: %v", res) + err = ss.CloseSend(nil) + test.Assert(t, err == nil, err) + t.Log("server handler return") + wg.Wait() +} + +func TestTransportServerStreaming(t *testing.T) { + cfd, sfd := netpoll.GetSysFdPairs() + cconn, err := netpoll.NewFDConnection(cfd) + test.Assert(t, err == nil, err) + sconn, err := netpoll.NewFDConnection(sfd) + test.Assert(t, err == nil, err) + + intHeader := make(IntHeader) + strHeader := make(streamx.Header) + ctrans := newTransport(clientTransport, testServiceInfo, cconn, nil) + rawClientStream, err := ctrans.WriteStream(context.Background(), "Bidi", intHeader, strHeader) + test.Assert(t, err == nil, err) + strans := newTransport(serverTransport, testServiceInfo, sconn, nil) + rawServerStream, err := strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + + var wg sync.WaitGroup + wg.Add(1) + // client + go func() { + defer wg.Done() + cs := newClientStream(rawClientStream) + req := new(testRequest) + req.B = "hello" + err := cs.SendMsg(context.Background(), req) + test.Assert(t, err == nil, err) + t.Logf("client stream send msg: %v", req) + + err = cs.CloseSend(context.Background()) + test.Assert(t, err == nil, err) + t.Logf("client stream close send") + + res := new(testResponse) + for { + err = cs.RecvMsg(context.Background(), res) + if err == io.EOF { + break + } + test.Assert(t, err == nil, err) + } + }() + + // server + ss := newServerStream(rawServerStream) + err = ss.SendHeader(strHeader) + test.Assert(t, err == nil, err) + t.Logf("server stream send header: %v", strHeader) + + req := new(testRequest) + err = ss.RecvMsg(context.Background(), req) + test.Assert(t, err == nil, err) + t.Logf("server stream recv msg: %v", req) + for i := 0; i < 3; i++ { + res := new(testResponse) + res.B = req.B + err = ss.SendMsg(context.Background(), res) + test.Assert(t, err == nil, err) + t.Logf("server stream send msg: %v", req) + } + err = ss.CloseSend(nil) + test.Assert(t, err == nil, err) + t.Log("server handler return") + wg.Wait() +} + +func TestTransportException(t *testing.T) { + cfd, sfd := netpoll.GetSysFdPairs() + cconn, err := netpoll.NewFDConnection(cfd) + test.Assert(t, err == nil, err) + sconn, err := netpoll.NewFDConnection(sfd) + test.Assert(t, err == nil, err) + + ctrans := newTransport(clientTransport, testServiceInfo, cconn, nil) + rawClientStream, err := ctrans.WriteStream(context.Background(), "Bidi", make(IntHeader), make(streamx.Header)) + test.Assert(t, err == nil, err) + strans := newTransport(serverTransport, testServiceInfo, sconn, nil) + rawServerStream, err := strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + + // server send exception + ss := newServerStream(rawServerStream) + targetException := thrift.NewApplicationException(remote.InternalError, "test") + err = ss.CloseSend(targetException) + test.Assert(t, err == nil, err) + // client recv exception + cs := newClientStream(rawClientStream) + res := new(testResponse) + err = cs.RecvMsg(context.Background(), res) + test.Assert(t, err != nil, err) + test.Assert(t, strings.Contains(err.Error(), targetException.Msg()), err.Error()) + + // server send illegal frame + rawClientStream, err = ctrans.WriteStream(context.Background(), "Bidi", make(IntHeader), make(streamx.Header)) + test.Assert(t, err == nil, err) + rawServerStream, err = strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + test.Assert(t, rawServerStream != nil, rawServerStream) + _, err = sconn.Write([]byte("helloxxxxxxxxxxxxxxxxxxxxxx")) + test.Assert(t, err == nil, err) + cs = newClientStream(rawClientStream) + err = cs.RecvMsg(context.Background(), res) + test.Assert(t, err != nil, err) + t.Logf("client stream send msg: %v %v", err, errors.Is(err, terrors.ErrIllegalFrame)) +} diff --git a/pkg/streamx/server_provider.go b/pkg/streamx/server_provider.go index d044853172..1b7f2d8edc 100644 --- a/pkg/streamx/server_provider.go +++ b/pkg/streamx/server_provider.go @@ -54,11 +54,13 @@ stream.Close() - server handler return */ +// ServerProvider define server provider API type ServerProvider interface { // Available detect if provider can process conn from its first N bytes Available(ctx context.Context, conn net.Conn) bool // ProtocolMath - // OnActive called when conn connected + // OnActive called when conn accepted OnActive(ctx context.Context, conn net.Conn) (context.Context, error) + // OnInactive called then conn disconnect OnInactive(ctx context.Context, conn net.Conn) (context.Context, error) // OnStream should read conn data and return a server stream OnStream(ctx context.Context, conn net.Conn) (context.Context, ServerStream, error) diff --git a/pkg/streamx/server_provider_internal.go b/pkg/streamx/server_provider_internal.go index 643753888b..f8d4be0c9e 100644 --- a/pkg/streamx/server_provider_internal.go +++ b/pkg/streamx/server_provider_internal.go @@ -21,6 +21,7 @@ import ( "net" ) +// NewServerProvider wrap specific server provider func NewServerProvider(ss ServerProvider) ServerProvider { if _, ok := ss.(*internalServerProvider); ok { return ss diff --git a/pkg/streamx/stream.go b/pkg/streamx/stream.go index e478a65024..5c30d399d1 100644 --- a/pkg/streamx/stream.go +++ b/pkg/streamx/stream.go @@ -56,6 +56,7 @@ client.CloseAndRecv(res) === EOF ==> server.Recv(EOF) ------------------- [Server Streaming] ------------------- ---------- (Request) returns (stream Response) ---------- client.Send(req) === req ==> server.Recv(req) +client.CloseSend() === EOF ==> server.Recv(EOF) client.Recv(res) <== res === server.Send(req) ... client.Recv(res) <== res === server.Send(req) @@ -85,6 +86,7 @@ const ( StreamingBidirectional = serviceinfo.StreamingBidirectional ) +// Stream define stream APIs type Stream interface { Mode() StreamingMode Service() string @@ -93,85 +95,85 @@ type Stream interface { RecvMsg(ctx context.Context, m any) error } +// ClientStream define client stream APIs type ClientStream interface { Stream + ClientStreamMetadata CloseSend(ctx context.Context) error } +// ServerStream define server stream APIs type ServerStream interface { Stream + ServerStreamMetadata } -// client 必须通过 metainfo.WithValue(ctx, ..) 给下游传递信息 -// client 必须通过 metainfo.GetValue(ctx, ..) 拿到当前 server 的透传信息 -// client 必须通过 Header() 拿到下游 server 的透传信息 +// ClientStreamMetadata define metainfo getter API +// client should use metainfo.WithValue(ctx, ..) to transmit metainfo to server +// client should use Header() to get metainfo from server +// client should use metainfo.GetValue(ctx, ..) get current server handler's metainfo type ClientStreamMetadata interface { Header() (Header, error) Trailer() (Trailer, error) } -// server 可以通过 Set/SendXXX 给上游回传信息 +// ServerStreamMetadata define metainfo setter API +// server should use SetHeader/SendHeader/SetTrailer to transmit metainfo to client type ServerStreamMetadata interface { SetHeader(hd Header) error SendHeader(hd Header) error SetTrailer(hd Trailer) error } +// ServerStreamingClient define client side server streaming APIs type ServerStreamingClient[Res any] interface { Recv(ctx context.Context) (*Res, error) ClientStream - ClientStreamMetadata } +// ServerStreamingServer define server side server streaming APIs type ServerStreamingServer[Res any] interface { Send(ctx context.Context, res *Res) error ServerStream - ServerStreamMetadata } +// ClientStreamingClient define client side client streaming APIs type ClientStreamingClient[Req, Res any] interface { Send(ctx context.Context, req *Req) error CloseAndRecv(ctx context.Context) (*Res, error) ClientStream - ClientStreamMetadata } +// ClientStreamingServer define server side client streaming APIs type ClientStreamingServer[Req, Res any] interface { Recv(ctx context.Context) (*Req, error) - // SendAndClose(ctx context.Context, res *Res) error ServerStream - ServerStreamMetadata } +// BidiStreamingClient define client side bidi streaming APIs type BidiStreamingClient[Req, Res any] interface { Send(ctx context.Context, req *Req) error Recv(ctx context.Context) (*Res, error) ClientStream - ClientStreamMetadata } +// BidiStreamingServer define server side bidi streaming APIs type BidiStreamingServer[Req, Res any] interface { Recv(ctx context.Context) (*Req, error) Send(ctx context.Context, res *Res) error ServerStream - ServerStreamMetadata -} - -type GenericStreamIOMiddlewareSetter interface { - SetStreamSendEndpoint(e StreamSendEndpoint) - SetStreamRecvEndpoint(e StreamSendEndpoint) } +// NewGenericClientStream return a generic client stream func NewGenericClientStream[Req, Res any](cs ClientStream) *GenericClientStream[Req, Res] { return &GenericClientStream[Req, Res]{ - ClientStream: cs, - ClientStreamMetadata: cs.(ClientStreamMetadata), + ClientStream: cs, } } +// GenericClientStream wrap stream IO with Send/Recv middlewares type GenericClientStream[Req, Res any] struct { ClientStream - ClientStreamMetadata StreamSendMiddleware StreamRecvMiddleware } @@ -219,16 +221,16 @@ func (x *GenericClientStream[Req, Res]) CloseAndRecv(ctx context.Context) (*Res, return x.Recv(ctx) } +// NewGenericServerStream return generic server stream func NewGenericServerStream[Req, Res any](ss ServerStream) *GenericServerStream[Req, Res] { return &GenericServerStream[Req, Res]{ - ServerStream: ss, - ServerStreamMetadata: ss.(ServerStreamMetadata), + ServerStream: ss, } } +// GenericServerStream wrap stream IO with Send/Recv middlewares type GenericServerStream[Req, Res any] struct { ServerStream - ServerStreamMetadata StreamSendMiddleware StreamRecvMiddleware } diff --git a/pkg/streamx/streamx_gen_service_test.go b/pkg/streamx/streamx_gen_service_test.go index 5c049ec79f..dbee613254 100644 --- a/pkg/streamx/streamx_gen_service_test.go +++ b/pkg/streamx/streamx_gen_service_test.go @@ -141,7 +141,7 @@ var streamingServiceInfo = &serviceinfo.ServiceInfo{ serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), ), }, - Extra: map[string]interface{}{"streamingFlag": true, "streamx": true}, + Extra: map[string]interface{}{"streaming": true, "streamx": true}, } // --- Define RegisterService interface --- diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index bccac99536..8f7e2d01f7 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -32,13 +32,15 @@ import ( "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/klog" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" + "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/streamxclient" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/streamxserver" "github.com/cloudwego/kitex/transport" @@ -58,7 +60,7 @@ func init() { providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_LongConn", ClientProvider: cp, ServerProvider: sp}) cp, _ = ttstream.NewClientProvider(streamingServiceInfo, ttstream.WithClientShortConnPool()) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_ShortConn", ClientProvider: cp, ServerProvider: sp}) - cp, _ = ttstream.NewClientProvider(streamingServiceInfo, ttstream.WithClientMuxConnPool()) + cp, _ = ttstream.NewClientProvider(streamingServiceInfo, ttstream.WithClientMuxConnPool(ttstream.MuxConnConfig{PoolSize: 8, MaxIdleTimeout: time.Millisecond * 1000})) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_Mux", ClientProvider: cp, ServerProvider: sp}) } @@ -66,13 +68,16 @@ func TestMain(m *testing.M) { go func() { log.Println(http.ListenAndServe("localhost:6060", nil)) }() - // klog.SetLevel(klog.LevelDebug) + klog.SetLevel(klog.LevelDebug) m.Run() } func TestStreamingBasic(t *testing.T) { for _, tc := range providerTestCases { t.Run(tc.Name, func(t *testing.T) { + concurrency := 100 + round := 5 + // === prepare test environment === addr := test.GetLocalAddress() ln, err := netpoll.CreateListener("tcp", addr) @@ -224,92 +229,110 @@ func TestStreamingBasic(t *testing.T) { octx := setMetadata(context.Background()) t.Logf("=== PingPong ===") - req := new(Request) - req.Message = "PingPong" - res, err := pingpongClient.PingPong(octx, req) - test.Assert(t, err == nil, err) - test.Assert(t, req.Message == res.Message, res) + var wg sync.WaitGroup + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := new(Request) + req.Message = "PingPong" + res, err := pingpongClient.PingPong(octx, req) + test.Assert(t, err == nil, err) + test.Assert(t, req.Message == res.Message, res) + }() + } + wg.Wait() t.Logf("=== Unary ===") - req = new(Request) - req.Type = 10000 - req.Message = "Unary" - res, err = streamClient.Unary(octx, req) - test.Assert(t, err == nil, err) - test.Assert(t, req.Type == res.Type, res.Type) - test.Assert(t, req.Message == res.Message, res.Message) - atomic.AddInt32(&serverStreamCount, -1) + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := new(Request) + req.Type = 10000 + req.Message = "Unary" + res, err := streamClient.Unary(octx, req) + test.Assert(t, err == nil, err) + test.Assert(t, req.Type == res.Type, res.Type) + test.Assert(t, req.Message == res.Message, res.Message) + atomic.AddInt32(&serverStreamCount, -1) + }() + } + wg.Wait() waitServerStreamDone() - test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(1)) - test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(1)) + test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(concurrency)) atomic.StoreInt32(&serverRecvCount, 0) atomic.StoreInt32(&serverSendCount, 0) // client stream - round := 5 t.Logf("=== ClientStream ===") - ctx, cs, err := streamClient.ClientStream(octx) - test.Assert(t, err == nil, err) - for i := 0; i < round; i++ { - req := new(Request) - req.Type = int32(i) - req.Message = "ClientStream" - err = cs.Send(ctx, req) - test.Assert(t, err == nil, err) + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx, cs, err := streamClient.ClientStream(octx) + test.Assert(t, err == nil, err) + for i := 0; i < round; i++ { + req := new(Request) + req.Type = int32(i) + req.Message = "ClientStream" + err = cs.Send(ctx, req) + test.Assert(t, err == nil, err) + } + res, err := cs.CloseAndRecv(ctx) + test.Assert(t, err == nil, err) + test.Assert(t, res.Message == "ClientStream", res.Message) + atomic.AddInt32(&serverStreamCount, -1) + testHeaderAndTrailer(t, cs) + }() } - res, err = cs.CloseAndRecv(ctx) - test.Assert(t, err == nil, err) - test.Assert(t, res.Message == "ClientStream", res.Message) - atomic.AddInt32(&serverStreamCount, -1) + wg.Wait() waitServerStreamDone() - testHeaderAndTrailer(t, cs) - test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(round)) - test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(1)) + test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(round)*int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(concurrency)) atomic.StoreInt32(&serverRecvCount, 0) atomic.StoreInt32(&serverSendCount, 0) - cs = nil - runtime.GC() // server stream t.Logf("=== ServerStream ===") - req = new(Request) - req.Message = "ServerStream" - ctx, ss, err := streamClient.ServerStream(octx, req) - test.Assert(t, err == nil, err) - received := 0 - for { - res, err := ss.Recv(ctx) - if errors.Is(err, io.EOF) { - break - } - test.Assert(t, err == nil, err) - received++ - t.Logf("Client ServerStream recv: %v", res) + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := new(Request) + req.Message = "ServerStream" + ctx, ss, err := streamClient.ServerStream(octx, req) + test.Assert(t, err == nil, err) + received := 0 + for { + res, err := ss.Recv(ctx) + if errors.Is(err, io.EOF) { + break + } + test.Assert(t, err == nil, err) + received++ + t.Logf("Client ServerStream recv: %v", res) + } + testHeaderAndTrailer(t, ss) + atomic.AddInt32(&serverStreamCount, -1) + }() } - err = ss.CloseSend(ctx) - test.Assert(t, err == nil, err) - atomic.AddInt32(&serverStreamCount, -1) + wg.Wait() waitServerStreamDone() - testHeaderAndTrailer(t, ss) - test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(1)) - test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(received)) + test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(concurrency)) atomic.StoreInt32(&serverRecvCount, 0) atomic.StoreInt32(&serverSendCount, 0) - ss = nil - runtime.GC() // bidi stream t.Logf("=== BidiStream ===") - concurrent := 32 - round = 5 - for c := 0; c < concurrent; c++ { - atomic.AddInt32(&serverStreamCount, -1) + for i := 0; i < concurrency; i++ { + wg.Add(3) go func() { + defer wg.Done() ctx, bs, err := streamClient.BidiStream(octx) test.Assert(t, err == nil, err) msg := "BidiStream" - var wg sync.WaitGroup - wg.Add(2) go func() { defer wg.Done() for i := 0; i < round; i++ { @@ -318,7 +341,7 @@ func TestStreamingBasic(t *testing.T) { err := bs.Send(ctx, req) test.Assert(t, err == nil, err) } - err = bs.CloseSend(ctx) + err := bs.CloseSend(ctx) test.Assert(t, err == nil, err) }() go func() { @@ -336,19 +359,20 @@ func TestStreamingBasic(t *testing.T) { test.Assert(t, i == round, i) }() testHeaderAndTrailer(t, bs) + atomic.AddInt32(&serverStreamCount, -1) }() } + wg.Wait() waitServerStreamDone() - test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(concurrent*round)) - test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(concurrent*round)) + test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(round*concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(round*concurrency)) atomic.StoreInt32(&serverRecvCount, 0) atomic.StoreInt32(&serverSendCount, 0) - runtime.GC() t.Logf("=== UnaryWithErr normalErr ===") - req = new(Request) + req := new(Request) req.Type = normalErr - res, err = streamClient.UnaryWithErr(ctx, req) + res, err := streamClient.UnaryWithErr(octx, req) test.Assert(t, res == nil, res) test.Assert(t, err != nil, err) assertNormalErr(t, err) @@ -356,7 +380,7 @@ func TestStreamingBasic(t *testing.T) { t.Logf("=== UnaryWithErr bizErr ===") req = new(Request) req.Type = bizErr - res, err = streamClient.UnaryWithErr(ctx, req) + res, err = streamClient.UnaryWithErr(octx, req) test.Assert(t, res == nil, res) test.Assert(t, err != nil, err) assertBizErr(t, err) @@ -543,8 +567,9 @@ func TestStreamingGoroutineLeak(t *testing.T) { } t.Logf("=== Checking streams GCed ===") + ngBefore := runtime.NumGoroutine() streams := 100 - streamList := make([]streamx.ServerStream, streams) + streamList := make([]streamx.ClientStream, streams) atomic.StoreInt32(&streamStarted, 0) for i := 0; i < streams; i++ { _, bs, err := streamClient.BidiStream(octx) @@ -553,7 +578,6 @@ func TestStreamingGoroutineLeak(t *testing.T) { } waitStreamStarted(streams) // before GC - ngBefore := runtime.NumGoroutine() test.Assert(t, runtime.NumGoroutine() > streams, runtime.NumGoroutine()) // after GC for i := 0; i < streams; i++ { @@ -567,6 +591,7 @@ func TestStreamingGoroutineLeak(t *testing.T) { t.Logf("=== Checking Streams Called and GCed ===") streams = 100 + ngBefore = runtime.NumGoroutine() for i := 0; i < streams; i++ { wg.Add(1) go func() { @@ -589,6 +614,42 @@ func TestStreamingGoroutineLeak(t *testing.T) { }() } wg.Wait() + for runtime.NumGoroutine() > ngBefore { + t.Logf("ngCurrent=%d > ngBefore=%d", runtime.NumGoroutine(), ngBefore) + runtime.GC() + time.Sleep(time.Millisecond * 50) + } + + t.Logf("=== Checking Server Streaming ===") + streams = 100 + ngBefore = runtime.NumGoroutine() + for i := 0; i < streams; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + req := new(Request) + req.Message = msg + ctx, ss, err := streamClient.ServerStream(octx, req) + test.Assert(t, err == nil, err) + + for { + res, err := ss.Recv(ctx) + if err == io.EOF { + break + } + test.Assert(t, err == nil, err) + test.Assert(t, res.Message == msg, res.Message) + } + testHeaderAndTrailer(t, ss) + }() + } + wg.Wait() + for runtime.NumGoroutine() > ngBefore { + t.Logf("ngCurrent=%d > ngBefore=%d", runtime.NumGoroutine(), ngBefore) + runtime.GC() + time.Sleep(time.Millisecond * 50) + } }) } } @@ -633,13 +694,13 @@ func TestStreamingException(t *testing.T) { // assert circuitBreaker error atomic.StoreInt32(&circuitBreaker, 1) - ctx, bs, err := streamClient.BidiStream(octx) + _, _, err := streamClient.BidiStream(octx) test.Assert(t, errors.Is(err, circuitBreakerErr), err) atomic.StoreInt32(&circuitBreaker, 0) // assert context deadline error ctx, cancel := context.WithTimeout(octx, time.Millisecond) - ctx, bs, err = streamClient.BidiStream(ctx) + ctx, bs, err := streamClient.BidiStream(ctx) test.Assert(t, err == nil, err) res, err := bs.Recv(ctx) cancel() diff --git a/server/streamxserver/server_gen.go b/server/streamxserver/server_gen.go index c92c8692b9..74d47067af 100644 --- a/server/streamxserver/server_gen.go +++ b/server/streamxserver/server_gen.go @@ -38,7 +38,7 @@ func InvokeStream[Req, Res any]( return errors.New("server stream is nil") } shandler := handler.(streamx.StreamHandler) - gs := streamx.NewGenericServerStream[Req, Res](sArgs.Stream()) + gs := streamx.NewGenericServerStream[Req, Res](sArgs.Stream().(streamx.ServerStream)) gs.SetStreamRecvMiddleware(shandler.StreamRecvMiddleware) gs.SetStreamSendMiddleware(shandler.StreamSendMiddleware) From 68574c1bd153e592e41f28a776fed8ca61ab939c Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Mon, 4 Nov 2024 13:40:52 +0800 Subject: [PATCH 13/34] chore: using thrift payload mock --- pkg/streamx/provider/ttstream/mock_test.go | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/pkg/streamx/provider/ttstream/mock_test.go b/pkg/streamx/provider/ttstream/mock_test.go index c0996650d9..f4e367b4c3 100644 --- a/pkg/streamx/provider/ttstream/mock_test.go +++ b/pkg/streamx/provider/ttstream/mock_test.go @@ -17,9 +17,9 @@ package ttstream import ( - "encoding/json" "fmt" + "github.com/cloudwego/frugal" "github.com/cloudwego/kitex/pkg/protocol/bthrift" kutils "github.com/cloudwego/kitex/pkg/utils" ) @@ -30,22 +30,16 @@ type testRequest struct { } func (p *testRequest) FastRead(buf []byte) (int, error) { - err := json.Unmarshal(buf, p) - if err != nil { - return 0, err - } - return len(buf), nil + return frugal.DecodeObject(buf, p) } func (p *testRequest) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { - data, _ := json.Marshal(p) - copy(buf, data) - return len(data) + n, _ := frugal.EncodeObject(buf, binaryWriter, p) + return n } func (p *testRequest) BLength() int { - data, _ := json.Marshal(p) - return len(data) + return frugal.EncodedSize(p) } func (p *testRequest) DeepCopy(s interface{}) error { From de4ba2b9bf9921b7901f69e2c56256ea3597e7d7 Mon Sep 17 00:00:00 2001 From: Joway Date: Wed, 6 Nov 2024 13:42:20 +0800 Subject: [PATCH 14/34] chore: add transport buffer test (#1605) --- pkg/streamx/provider/jsonrpc/client_option.go | 25 --- .../provider/jsonrpc/client_provier.go | 63 ------ pkg/streamx/provider/jsonrpc/protocol.go | 184 ------------------ .../provider/jsonrpc/server_provider.go | 77 -------- pkg/streamx/provider/jsonrpc/stream.go | 138 ------------- pkg/streamx/provider/jsonrpc/transport.go | 174 ----------------- .../provider/jsonrpc/transport_test.go | 159 --------------- pkg/streamx/provider/ttstream/stream.go | 35 ++-- .../provider/ttstream/stream_reader.go | 2 +- .../stream_writer.go} | 11 +- pkg/streamx/provider/ttstream/transport.go | 54 ++--- .../provider/ttstream/transport_buffer.go | 2 - .../ttstream/transport_buffer_test.go | 79 ++++++++ .../provider/ttstream/transport_test.go | 10 + pkg/streamx/streamx_gen_service_test.go | 2 +- 15 files changed, 140 insertions(+), 875 deletions(-) delete mode 100644 pkg/streamx/provider/jsonrpc/client_option.go delete mode 100644 pkg/streamx/provider/jsonrpc/client_provier.go delete mode 100644 pkg/streamx/provider/jsonrpc/protocol.go delete mode 100644 pkg/streamx/provider/jsonrpc/server_provider.go delete mode 100644 pkg/streamx/provider/jsonrpc/stream.go delete mode 100644 pkg/streamx/provider/jsonrpc/transport.go delete mode 100644 pkg/streamx/provider/jsonrpc/transport_test.go rename pkg/streamx/provider/{jsonrpc/server_option.go => ttstream/stream_writer.go} (76%) create mode 100644 pkg/streamx/provider/ttstream/transport_buffer_test.go diff --git a/pkg/streamx/provider/jsonrpc/client_option.go b/pkg/streamx/provider/jsonrpc/client_option.go deleted file mode 100644 index 3a35f5c1f0..0000000000 --- a/pkg/streamx/provider/jsonrpc/client_option.go +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package jsonrpc - -type ClientProviderOption func(cp *clientProvider) - -func WithClientPayloadLimit(limit int) ClientProviderOption { - return func(cp *clientProvider) { - cp.payloadLimit = limit - } -} diff --git a/pkg/streamx/provider/jsonrpc/client_provier.go b/pkg/streamx/provider/jsonrpc/client_provier.go deleted file mode 100644 index 987206920d..0000000000 --- a/pkg/streamx/provider/jsonrpc/client_provier.go +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package jsonrpc - -import ( - "context" - "net" - - "github.com/cloudwego/kitex/pkg/kerrors" - "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/kitex/pkg/streamx" -) - -var _ streamx.ClientProvider = (*clientProvider)(nil) - -func NewClientProvider(sinfo *serviceinfo.ServiceInfo, opts ...ClientProviderOption) (streamx.ClientProvider, error) { - cp := new(clientProvider) - cp.sinfo = sinfo - for _, opt := range opts { - opt(cp) - } - return cp, nil -} - -type clientProvider struct { - sinfo *serviceinfo.ServiceInfo - payloadLimit int -} - -func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (streamx.ClientStream, error) { - invocation := ri.Invocation() - method := invocation.MethodName() - addr := ri.To().Address() - if addr == nil { - return nil, kerrors.ErrNoDestAddress - } - conn, err := net.Dial(addr.Network(), addr.String()) - if err != nil { - return nil, err - } - trans := newTransport(c.sinfo, conn) - s, err := trans.newStream(method) - if err != nil { - return nil, err - } - cs := newClientStream(s) - return cs, err -} diff --git a/pkg/streamx/provider/jsonrpc/protocol.go b/pkg/streamx/provider/jsonrpc/protocol.go deleted file mode 100644 index cbac87b90d..0000000000 --- a/pkg/streamx/provider/jsonrpc/protocol.go +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package jsonrpc - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "io" - - "github.com/cloudwego/netpoll" -) - -/* JSON RPC Protocol - -=== client create a new stream === -- send {type=META, sid=1, service="a.b.c", method="xxx"} -- send {type=DATA, sid=1, service="a.b.c", method="xxx", payload="..."} -- recv {type=DATA, sid=1, service="a.b.c", method="xxx", payload="..."} - -=== server accept a new stream === -- recv {type=META, sid=1, service="a.b.c", method="xxx"} -- recv {type=DATA, sid=1, service="a.b.c", method="xxx", payload="..."} -- send {type=DATA, sid=1, service="a.b.c", method="xxx", payload="..."} - -=== client try to close stream === -- client send {type=EOF , sid=1, service="a.b.c", method="xxx"} -- server recv {type=EOF , sid=1, service="a.b.c", method="xxx"} -- server send {type=EOF , sid=1, service="a.b.c", method="xxx"}: server close stream -- client recv {type=EOF , sid=1, service="a.b.c", method="xxx"}: client close stream - -=== server try to close stream === -- server send {type=EOF , sid=1, service="a.b.c", method="xxx"} -- client recv {type=EOF , sid=1, service="a.b.c", method="xxx"} -- client send {type=EOF , sid=1, service="a.b.c", method="xxx"}: client close stream -- server recv {type=EOF , sid=1, service="a.b.c", method="xxx"}: server close stream -*/ - -const ( - frameMagic uint32 = 0x123321 - // meta: new stream - frameTypeMeta = 0 - // data: stream streamSend/streamRecv data - frameTypeData = 1 - // eof: stream closed by peer - frameTypeEOF = 2 -) - -// Frame define a JSON RPC protocol frame -// - 4 bytes: frameMagic -// - 4 bytes: data size -// - 4 bytes: frame kind -// - 4 bytes: stream id -// - 4 bytes: service name size -// - service_name_size bytes: service name -// - 4 bytes: method name size -// - method_name_size bytes: method name -// - ... bytes: json payload -type Frame struct { - typ int - sid int - service string - method string - payload []byte -} - -func newFrame(typ, sid int, service, method string, payload []byte) Frame { - return Frame{ - typ: typ, - sid: sid, - service: service, - method: method, - payload: payload, - } -} - -func EncodeFrame(writer io.Writer, frame Frame) (err error) { - // not include data size field length - dataSize := 4*4 + len(frame.service) + len(frame.method) + len(frame.payload) - data := make([]byte, 4+4+dataSize) - offset := 0 - - // header - binary.BigEndian.PutUint32(data[offset:offset+4], frameMagic) - offset += 4 - binary.BigEndian.PutUint32(data[offset:offset+4], uint32(dataSize)) - offset += 4 - - // data - binary.BigEndian.PutUint32(data[offset:offset+4], uint32(frame.typ)) - offset += 4 - binary.BigEndian.PutUint32(data[offset:offset+4], uint32(frame.sid)) - offset += 4 - binary.BigEndian.PutUint32(data[offset:offset+4], uint32(len(frame.service))) - offset += 4 - copy(data[offset:offset+len(frame.service)], frame.service) - offset += len(frame.service) - binary.BigEndian.PutUint32(data[offset:offset+4], uint32(len(frame.method))) - offset += 4 - copy(data[offset:offset+len(frame.method)], frame.method) - offset += len(frame.method) - copy(data[offset:offset+len(frame.payload)], frame.payload) - offset += len(frame.payload) - _ = offset - - idx := 0 - for idx < len(data) { - n, err := writer.Write(data[idx:]) - if err != nil { - return err - } - idx += n - } - return nil -} - -func EncodePayload(msg any) ([]byte, error) { - return json.Marshal(msg) -} - -func DecodeFrame(reader io.Reader) (frame Frame, err error) { - header := make([]byte, 8) - _, err = io.ReadFull(reader, header) - if err != nil { - return - } - magic := binary.BigEndian.Uint32(header[:4]) - size := binary.BigEndian.Uint32(header[4:8]) - if magic != frameMagic { - err = fmt.Errorf("invalid frame magic number: %d", magic) - return - } - - data := make([]byte, size) - _, err = io.ReadFull(reader, data) - if err != nil { - return - } - offset := 0 - frame.typ = int(binary.BigEndian.Uint32(data[offset : offset+4])) - offset += 4 - frame.sid = int(binary.BigEndian.Uint32(data[offset : offset+4])) - offset += 4 - serviceSize := int(binary.BigEndian.Uint32(data[offset : offset+4])) - offset += 4 - frame.service = string(data[offset : offset+serviceSize]) - offset += serviceSize - methodSize := int(binary.BigEndian.Uint32(data[offset : offset+4])) - offset += 4 - frame.method = string(data[offset : offset+methodSize]) - offset += methodSize - frame.payload = data[offset:] - return -} - -func DecodePayload(payload []byte, msg any) (err error) { - return json.Unmarshal(payload, msg) -} - -func checkFrame(conn netpoll.Connection) error { - header, err := conn.Reader().Peek(8) - if err != nil { - return err - } - magic := binary.BigEndian.Uint32(header[:4]) - if magic != frameMagic { - return fmt.Errorf("invalid frame magic number: %d", magic) - } - return nil -} diff --git a/pkg/streamx/provider/jsonrpc/server_provider.go b/pkg/streamx/provider/jsonrpc/server_provider.go deleted file mode 100644 index 6f74bde151..0000000000 --- a/pkg/streamx/provider/jsonrpc/server_provider.go +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package jsonrpc - -import ( - "context" - "net" - - "github.com/cloudwego/netpoll" - - "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/kitex/pkg/streamx" -) - -type serverTransCtxKey struct{} - -func NewServerProvider(sinfo *serviceinfo.ServiceInfo, opts ...ServerProviderOption) (streamx.ServerProvider, error) { - sp := new(serverProvider) - sp.sinfo = sinfo - for _, opt := range opts { - opt(sp) - } - return sp, nil -} - -var _ streamx.ServerProvider = (*serverProvider)(nil) - -type serverProvider struct { - sinfo *serviceinfo.ServiceInfo - payloadLimit int -} - -func (s serverProvider) Available(ctx context.Context, conn net.Conn) bool { - err := checkFrame(conn.(netpoll.Connection)) - return err == nil -} - -func (s serverProvider) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { - trans := newTransport(s.sinfo, conn) - return context.WithValue(ctx, serverTransCtxKey{}, trans), nil -} - -func (s serverProvider) OnInactive(ctx context.Context, conn net.Conn) (context.Context, error) { - return ctx, nil -} - -func (s serverProvider) OnStream(ctx context.Context, conn net.Conn) (context.Context, streamx.ServerStream, error) { - trans, _ := ctx.Value(serverTransCtxKey{}).(*transport) - if trans == nil { - return nil, nil, nil - } - st, err := trans.readStream() - if err != nil { - return nil, nil, err - } - ss := newServerStream(st) - return ctx, ss, nil -} - -func (s serverProvider) OnStreamFinish(ctx context.Context, ss streamx.ServerStream, err error) (context.Context, error) { - sst := ss.(*serverStream) - return ctx, sst.sendEOF() -} diff --git a/pkg/streamx/provider/jsonrpc/stream.go b/pkg/streamx/provider/jsonrpc/stream.go deleted file mode 100644 index 6e36b95d46..0000000000 --- a/pkg/streamx/provider/jsonrpc/stream.go +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package jsonrpc - -import ( - "context" - "log" - "sync/atomic" - - "github.com/cloudwego/kitex/pkg/streamx" -) - -var ( - _ streamx.ClientStream = (*clientStream)(nil) - _ streamx.ServerStream = (*serverStream)(nil) - _ streamx.ClientStreamMetadata = (*clientStream)(nil) - _ streamx.ServerStreamMetadata = (*serverStream)(nil) -) - -func newStream(trans *transport, sid int, mode streamx.StreamingMode, service, method string) (s *stream) { - s = new(stream) - s.id = sid - s.mode = mode - s.service = service - s.method = method - s.trans = trans - return s -} - -type stream struct { - id int - mode streamx.StreamingMode - service string - method string - selfEOF int32 - peerEOF int32 - trans *transport -} - -func (s *stream) Header() (streamx.Header, error) { - return make(streamx.Header), nil -} - -func (s *stream) Trailer() (streamx.Trailer, error) { - return make(streamx.Trailer), nil -} - -func (s *stream) Mode() streamx.StreamingMode { - return s.mode -} - -func (s *stream) Service() string { - return s.service -} - -func (s *stream) Method() string { - return s.method -} - -func (s *stream) sendEOF() (err error) { - if !atomic.CompareAndSwapInt32(&s.selfEOF, 0, 1) { - return nil - } - log.Printf("stream[%s] send EOF", s.method) - return s.trans.streamCloseSend(s) -} - -func (s *stream) recvEOF() (err error) { - if !atomic.CompareAndSwapInt32(&s.peerEOF, 0, 1) { - return nil - } - log.Printf("stream[%s] recv EOF", s.method) - return s.trans.streamCloseRecv(s) -} - -func (s *stream) SendMsg(ctx context.Context, res any) error { - payload, err := EncodePayload(res) - if err != nil { - return err - } - return s.trans.streamSend(s, payload) -} - -func (s *stream) RecvMsg(ctx context.Context, req any) error { - payload, err := s.trans.streamRecv(s) - if err != nil { - return err - } - return DecodePayload(payload, req) -} - -func newClientStream(s *stream) *clientStream { - cs := &clientStream{stream: s} - return cs -} - -type clientStream struct { - *stream -} - -func (s *clientStream) CloseSend(ctx context.Context) error { - return s.sendEOF() -} - -func newServerStream(s *stream) streamx.ServerStream { - ss := &serverStream{stream: s} - return ss -} - -type serverStream struct { - *stream -} - -func (s *serverStream) SetHeader(hd streamx.Header) error { - return nil -} - -func (s *serverStream) SendHeader(hd streamx.Header) error { - return nil -} - -func (s *serverStream) SetTrailer(hd streamx.Trailer) error { - return nil -} diff --git a/pkg/streamx/provider/jsonrpc/transport.go b/pkg/streamx/provider/jsonrpc/transport.go deleted file mode 100644 index 399fe3ba9d..0000000000 --- a/pkg/streamx/provider/jsonrpc/transport.go +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package jsonrpc - -import ( - "errors" - "fmt" - "io" - "net" - "sync" - "sync/atomic" - - "github.com/cloudwego/netpoll" - - "github.com/cloudwego/kitex/pkg/klog" - "github.com/cloudwego/kitex/pkg/serviceinfo" -) - -type transport struct { - sinfo *serviceinfo.ServiceInfo - conn net.Conn - streams sync.Map - sch chan *stream - rch map[int]chan Frame - wch chan Frame - stop chan struct{} -} - -func newTransport(sinfo *serviceinfo.ServiceInfo, conn net.Conn) *transport { - t := &transport{ - sinfo: sinfo, - conn: conn, - streams: sync.Map{}, - sch: make(chan *stream), - rch: map[int]chan Frame{}, - wch: make(chan Frame), - stop: make(chan struct{}), - } - go func() { - err := t.loopRead() - if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) && !errors.Is(err, netpoll.ErrConnClosed) { - klog.Debugf("transport loop read err: %v", err) - } - }() - go func() { - err := t.loopWrite() - if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { - klog.Debugf("transport loop write err: %v", err) - } - }() - return t -} - -func (t *transport) close() (err error) { - select { - case <-t.stop: - default: - close(t.stop) - } - return nil -} - -func (t *transport) streamSend(s *stream, payload []byte) (err error) { - f := newFrame(frameTypeData, s.id, s.service, s.method, payload) - t.wch <- f - return nil -} - -func (t *transport) streamRecv(s *stream) (payload []byte, err error) { - f := <-t.rch[s.id] - if f.sid != s.id { // f.sid == 0 means it's a empty frame - return nil, io.EOF - } - return f.payload, nil -} - -func (t *transport) loopRead() error { - for { - // decode frame - frame, err := DecodeFrame(t.conn) - if err != nil { - return err - } - - // prepare stream - switch frame.typ { - case frameTypeMeta: // new stream - smode := t.sinfo.MethodInfo(frame.method).StreamingMode() - s := newStream(t, frame.sid, smode, frame.service, frame.method) - t.streams.Store(s.id, s) - t.rch[s.id] = make(chan Frame, 1024) - t.sch <- s - case frameTypeData, frameTypeEOF: // stream streamRecv/close - iss, ok := t.streams.Load(frame.sid) - if !ok { - return fmt.Errorf("stream not found in stream map: sid=%d", frame.sid) - } - s := iss.(*stream) - switch frame.typ { - case frameTypeEOF: - err = s.recvEOF() - return err - case frameTypeData: - // process data frame - t.rch[s.id] <- frame - } - } - } -} - -func (t *transport) loopWrite() error { - for { - select { - case <-t.stop: - return nil - case frame := <-t.wch: - err := EncodeFrame(t.conn, frame) - if err != nil { - return err - } - } - } -} - -var clientStreamID uint32 - -func (t *transport) newStream(method string) (*stream, error) { - sid := int(atomic.AddUint32(&clientStreamID, 1)) - smode := t.sinfo.MethodInfo(method).StreamingMode() - service := t.sinfo.ServiceName - f := newFrame(frameTypeMeta, sid, service, method, []byte{}) - s := newStream(t, sid, smode, service, method) - t.streams.Store(s.id, s) - t.rch[s.id] = make(chan Frame, 1024) - t.wch <- f // create stream - return s, nil -} - -func (t *transport) streamCloseRecv(s *stream) (err error) { - //for len(t.rch[s.id]) > 0 { - // runtime.Gosched() - //} - close(t.rch[s.id]) - return nil -} - -func (t *transport) streamCloseSend(s *stream) (err error) { - f := newFrame(frameTypeEOF, s.id, s.service, s.method, []byte("EOF")) - t.wch <- f - return nil -} - -func (t *transport) readStream() (*stream, error) { - select { - case <-t.stop: - return nil, io.EOF - case s := <-t.sch: - return s, nil - } -} diff --git a/pkg/streamx/provider/jsonrpc/transport_test.go b/pkg/streamx/provider/jsonrpc/transport_test.go deleted file mode 100644 index 187e15e7c8..0000000000 --- a/pkg/streamx/provider/jsonrpc/transport_test.go +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package jsonrpc - -import ( - "bufio" - "bytes" - "context" - "errors" - "io" - "net" - "sync/atomic" - "testing" - "time" - - "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/serviceinfo" -) - -func TestCodec(t *testing.T) { - var buf bytes.Buffer - writer := bufio.NewWriter(&buf) - f1 := newFrame(0, 1, "a.b.c", "test", []byte("12345")) - err := EncodeFrame(writer, f1) - test.Assert(t, err == nil, err) - _ = writer.Flush() - reader := bufio.NewReader(&buf) - f2, err := DecodeFrame(reader) - test.Assert(t, err == nil, err) - test.Assert(t, f2.method == f1.method, f2.method) - test.Assert(t, string(f2.payload) == string(f1.payload), f2.payload) -} - -func TestTransport(t *testing.T) { - type TestRequest struct { - A int `json:"A,omitempty"` - B string `json:"B,omitempty"` - } - type TestResponse = TestRequest - method := "BidiStream" - sinfo := &serviceinfo.ServiceInfo{ - ServiceName: "a.b.c", - Methods: map[string]serviceinfo.MethodInfo{ - method: serviceinfo.NewMethodInfo( - func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return nil - }, - nil, - nil, - false, - serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), - ), - }, - Extra: map[string]interface{}{"streaming": true}, - } - - addr := test.GetLocalAddress() - ln, err := net.Listen("tcp", addr) - test.Assert(t, err == nil, err) - - // Server - var connDone int32 - var streamDone int32 - go func() { - for { - conn, err := ln.Accept() - if (conn == nil && err == nil) || errors.Is(err, net.ErrClosed) { - return - } - test.Assert(t, err == nil, err) - go func() { - defer atomic.AddInt32(&connDone, -1) - server := newTransport(sinfo, conn) - st, nerr := server.readStream() - if nerr != nil { - if nerr == io.EOF { - return - } - t.Error(nerr) - } - go func() { - defer atomic.AddInt32(&streamDone, -1) - for { - ctx := context.Background() - req := new(TestRequest) - nerr = st.RecvMsg(ctx, req) - if errors.Is(nerr, io.EOF) { - return - } - t.Logf("server recv msg: %v %v", req, nerr) - res := req - nerr = st.SendMsg(ctx, res) - t.Logf("server send msg: %v %v", res, nerr) - if nerr != nil { - if nerr == io.EOF { - return - } - t.Error(nerr) - } - } - }() - }() - } - }() - time.Sleep(time.Millisecond * 100) - - // Client - atomic.AddInt32(&connDone, 1) - conn, err := net.Dial("tcp", addr) - test.Assert(t, err == nil, err) - trans := newTransport(sinfo, conn) - s, err := trans.newStream(method) - test.Assert(t, err == nil, err) - cs := newClientStream(s) - - req := new(TestRequest) - req.A = 12345 - req.B = "hello" - res := new(TestResponse) - ctx := context.Background() - err = cs.SendMsg(ctx, req) - t.Logf("client send msg: %v", req) - test.Assert(t, err == nil, err) - err = cs.RecvMsg(ctx, res) - t.Logf("client recv msg: %v", res) - test.Assert(t, err == nil, err) - test.Assert(t, req.A == res.A, res) - test.Assert(t, req.B == res.B, res) - - // close stream - err = cs.CloseSend(ctx) - test.Assert(t, err == nil, err) - for atomic.LoadInt32(&streamDone) != 0 { - time.Sleep(time.Millisecond * 10) - } - - // close conn - err = trans.close() - test.Assert(t, err == nil, err) - err = ln.Close() - test.Assert(t, err == nil, err) - for atomic.LoadInt32(&connDone) != 0 { - time.Sleep(time.Millisecond * 10) - } -} diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index 4b4f21ac52..0f4e9d7755 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -42,12 +42,12 @@ var ( _ StreamMeta = (*stream)(nil) ) -func newStream(ctx context.Context, trans *transport, mode streamx.StreamingMode, smeta streamFrame) *stream { +func newStream(ctx context.Context, writer streamWriter, mode streamx.StreamingMode, smeta streamFrame) *stream { s := new(stream) s.streamFrame = smeta s.StreamMeta = newStreamMeta() s.reader = newStreamReader() - s.trans = trans + s.writer = writer s.mode = mode s.wheader = make(streamx.Header) s.wtrailer = make(streamx.Trailer) @@ -83,7 +83,7 @@ type stream struct { streamFrame StreamMeta reader *streamReader - trans *transport + writer streamWriter mode streamx.StreamingMode wheader streamx.Header // wheader == nil means it already be sent wtrailer streamx.Trailer // wtrailer == nil means it already be sent @@ -117,14 +117,14 @@ func (s *stream) Method() string { func (s *stream) SendMsg(ctx context.Context, msg any) (err error) { if atomic.LoadInt32(&s.selfEOF) != 0 { - return terrors.ErrIllegalOperation.WithCause(errors.New("stream is close send")) + return terrors.ErrIllegalOperation.WithCause(errors.New("stream is closed send")) } // encode payload payload, err := EncodePayload(ctx, msg) if err != nil { return err } - // tracing + // tracing send size ri := rpcinfo.GetRPCInfo(ctx) if ri != nil && ri.Stats() != nil { if rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()); rpcStats != nil { @@ -149,7 +149,7 @@ func (s *stream) RecvMsg(ctx context.Context, data any) error { // payload will not be access after decode mcache.Free(payload) - // tracing + // tracing recv size ri := rpcinfo.GetRPCInfo(ctx) if ri != nil && ri.Stats() != nil { if rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()); rpcStats != nil { @@ -236,14 +236,12 @@ func (s *stream) tryRunCloseCallback() { cb() } } - s.trans.deleteStream(s.sid) - s.trans.recycle() + _ = s.writer.CloseStream(s.sid) } func (s *stream) writeFrame(ftype int32, header streamx.Header, trailer streamx.Trailer, payload []byte) (err error) { - return s.trans.writeFrame( - streamFrame{sid: s.sid, method: s.method, header: header, trailer: trailer}, ftype, payload, - ) + fr := newFrame(streamFrame{sid: s.sid, method: s.method, header: header, trailer: trailer}, ftype, payload) + return s.writer.WriteFrame(fr) } // writeHeader copy kvs into s.wheader @@ -287,7 +285,7 @@ func (s *stream) sendTrailer(exception error) (err error) { if wtrailer == nil { return fmt.Errorf("stream trailer already sent") } - klog.Debugf("transport[%d]-stream[%d] send trailer: err=%v", s.trans.kind, s.sid, exception) + klog.Debugf("stream[%d] send trailer: err=%v", s.sid, exception) var payload []byte if exception != nil { @@ -359,14 +357,9 @@ func (s *stream) onReadTrailerFrame(fr *Frame) (err error) { } klog.Debugf("stream[%d] recv trailer: %v, exception: %v", s.sid, s.trailer, exception) - switch s.trans.kind { - case clientTransport: - // if client recv trailer, server handler must be return, - // so we don't need to send data anymore - err = s.closeRecv(exception) - case serverTransport: - // if server recv trailer, we only need to close recv but still can send data - err = s.closeRecv(exception) - } + // if client recv trailer, server handler must be return, + // so we don't need to send data anymore + // if server recv trailer, we only need to close recv but still can send data + err = s.closeRecv(exception) return err } diff --git a/pkg/streamx/provider/ttstream/stream_reader.go b/pkg/streamx/provider/ttstream/stream_reader.go index fe23ba7086..91daf2e349 100644 --- a/pkg/streamx/provider/ttstream/stream_reader.go +++ b/pkg/streamx/provider/ttstream/stream_reader.go @@ -46,7 +46,7 @@ func newStreamReader() *streamReader { func (s *streamReader) input(ctx context.Context, payload []byte) { err := s.pipe.Write(ctx, streamMsg{payload: payload}) if err != nil { - klog.Errorf("pipe write failed: %v", err) + klog.Errorf("stream pipe input failed: %v", err) } } diff --git a/pkg/streamx/provider/jsonrpc/server_option.go b/pkg/streamx/provider/ttstream/stream_writer.go similarity index 76% rename from pkg/streamx/provider/jsonrpc/server_option.go rename to pkg/streamx/provider/ttstream/stream_writer.go index 0cc2ae4017..398867733e 100644 --- a/pkg/streamx/provider/jsonrpc/server_option.go +++ b/pkg/streamx/provider/ttstream/stream_writer.go @@ -14,12 +14,11 @@ * limitations under the License. */ -package jsonrpc +package ttstream -type ServerProviderOption func(pc *serverProvider) +var _ streamWriter = (*transport)(nil) -func WithServerPayloadLimit(limit int) ServerProviderOption { - return func(s *serverProvider) { - s.payloadLimit = limit - } +type streamWriter interface { + WriteFrame(f *Frame) error + CloseStream(sid int32) error } diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index ad4f84c0cd..cb59816960 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -85,8 +85,6 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne if err != nil { if !isIgnoreError(err) { klog.Warnf("transport[%d-%s] loop read err: %v", t.kind, t.Addr(), err) - } else { - klog.Debugf("transport[%d-%s] loop read err: %v", t.kind, t.Addr(), err) } // if connection is closed by peer, loop read should return ErrConnClosed error, // so we should close transport here @@ -101,8 +99,6 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne if err != nil { if !isIgnoreError(err) { klog.Warnf("transport[%d-%s] loop write err: %v", t.kind, t.Addr(), err) - } else { - klog.Debugf("transport[%d-%s] loop write err: %v", t.kind, t.Addr(), err) } _ = t.Close(err) } @@ -170,12 +166,6 @@ func (t *transport) deleteStream(sid int32) { t.streams.Delete(sid) } -func (t *transport) recycle() { - if t.pool != nil { - t.pool.Put(t) - } -} - func (t *transport) readFrame(reader bufiox.Reader) error { fr, err := DecodeFrame(context.Background(), reader) if err != nil { @@ -196,10 +186,7 @@ func (t *transport) readFrame(reader bufiox.Reader) error { var ok bool s, ok = t.loadStream(fr.sid) if !ok { - klog.Errorf( - "transport[%d] read a unknown stream: frame[%s]", - t.kind, fr.String(), - ) + klog.Errorf("transport[%d] read a unknown stream: frame[%s]", t.kind, fr.String()) // ignore unknown stream error err = nil } else { @@ -253,7 +240,7 @@ func (t *transport) loopWrite() error { } for i := 0; i < n; i++ { fr := fcache[i] - klog.Debugf("transport[%d] EncodeFrame: fr=%s", t.kind, fr) + klog.Debugf("transport[%d] EncodeFrame: frame=%s", t.kind, fr) if err = EncodeFrame(context.Background(), writer, fr); err != nil { return err } @@ -262,20 +249,40 @@ func (t *transport) loopWrite() error { } recycleFrame(fr) } - if err = t.conn.Writer().Flush(); err != nil { + if err = writer.Flush(); err != nil { return err } } } -// writeFrame is concurrent safe -func (t *transport) writeFrame(sframe streamFrame, ftype int32, payload []byte) (err error) { - frame := newFrame(sframe, ftype, payload) - return t.fpipe.Write(context.Background(), frame) +// WriteFrame is concurrent safe +func (t *transport) WriteFrame(fr *Frame) (err error) { + return t.fpipe.Write(context.Background(), fr) +} + +func (t *transport) CloseStream(sid int32) (err error) { + t.deleteStream(sid) + // clientTransport may require to return the transport to transPool + if t.pool != nil { + t.pool.Put(t) + } + return nil } var clientStreamID int32 +// stream id can be negative +func genStreamID() int32 { + // here have a really rare case that one connection get two same stream id when exist (2*max_int32) streams, + // but it just happens in theory because in real world, no service can process soo many streams in the same time. + sid := atomic.AddInt32(&clientStreamID, 1) + // we preserve streamId=0 for connection level control frame in the future. + if sid == 0 { + sid = atomic.AddInt32(&clientStreamID, 1) + } + return sid +} + // WriteStream create new stream on current connection // it's typically used by client side // newStream is concurrency safe @@ -286,15 +293,14 @@ func (t *transport) WriteStream( return nil, fmt.Errorf("transport already be used as other kind") } - sid := atomic.AddInt32(&clientStreamID, 1) + sid := genStreamID() smode := t.sinfo.MethodInfo(method).StreamingMode() // new stream first s := newStream(ctx, t, smode, streamFrame{sid: sid, method: method}) t.storeStream(s) // send create stream request for server - err := t.writeFrame( - streamFrame{sid: sid, method: method, header: strHeader, meta: intHeader}, headerFrameType, nil, - ) + fr := newFrame(streamFrame{sid: sid, method: method, header: strHeader, meta: intHeader}, headerFrameType, nil) + err := t.WriteFrame(fr) if err != nil { return nil, err } diff --git a/pkg/streamx/provider/ttstream/transport_buffer.go b/pkg/streamx/provider/ttstream/transport_buffer.go index 36620fac46..312eb9ccea 100644 --- a/pkg/streamx/provider/ttstream/transport_buffer.go +++ b/pkg/streamx/provider/ttstream/transport_buffer.go @@ -129,8 +129,6 @@ func (c *writerBuffer) WrittenLen() (length int) { func (c *writerBuffer) Flush() (err error) { err = c.writer.Flush() - c.writer = nil c.writeSize = 0 - writerBufferPool.Put(c) return err } diff --git a/pkg/streamx/provider/ttstream/transport_buffer_test.go b/pkg/streamx/provider/ttstream/transport_buffer_test.go new file mode 100644 index 0000000000..62f5d0c48b --- /dev/null +++ b/pkg/streamx/provider/ttstream/transport_buffer_test.go @@ -0,0 +1,79 @@ +//go:build !windows + +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/netpoll" +) + +func TestTransportBuffer(t *testing.T) { + rfd, wfd := netpoll.GetSysFdPairs() + rconn, err := netpoll.NewFDConnection(rfd) + test.Assert(t, err == nil, err) + wconn, err := netpoll.NewFDConnection(wfd) + test.Assert(t, err == nil, err) + + rbuf := newReaderBuffer(rconn.Reader()) + wbuf := newWriterBuffer(wconn.Writer()) + msg := make([]byte, 1024) + for i := 0; i < len(msg); i++ { + msg[i] = 'a' + byte(i%26) + } + + // test Malloc + buf, err := wbuf.Malloc(len(msg)) + test.Assert(t, err == nil, err) + test.DeepEqual(t, wbuf.WrittenLen(), len(msg)) + copy(buf, msg) + err = wbuf.Flush() + test.Assert(t, err == nil, err) + test.DeepEqual(t, wbuf.WrittenLen(), 0) + + // test ReadBinary + buf = make([]byte, len(msg)) + n, err := rbuf.ReadBinary(buf) + test.Assert(t, err == nil, err) + test.DeepEqual(t, buf, msg) + test.DeepEqual(t, n, len(msg)) + test.DeepEqual(t, rbuf.ReadLen(), len(msg)) + err = rbuf.Release(nil) + test.Assert(t, err == nil, err) + test.DeepEqual(t, rbuf.ReadLen(), 0) + + // test WriteBinary + n, err = wbuf.WriteBinary(msg) + test.Assert(t, err == nil, err) + test.DeepEqual(t, n, len(msg)) + test.DeepEqual(t, wbuf.WrittenLen(), len(msg)) + err = wbuf.Flush() + test.Assert(t, err == nil, err) + test.DeepEqual(t, wbuf.WrittenLen(), 0) + + // test Next + buf, err = rbuf.Next(len(msg)) + test.Assert(t, err == nil, err) + test.DeepEqual(t, buf, msg) + test.DeepEqual(t, rbuf.ReadLen(), len(msg)) + err = rbuf.Release(nil) + test.Assert(t, err == nil, err) + test.DeepEqual(t, rbuf.ReadLen(), 0) +} diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go index 444206a00d..68d4e5ecf0 100644 --- a/pkg/streamx/provider/ttstream/transport_test.go +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -22,8 +22,10 @@ import ( "context" "errors" "io" + "math" "strings" "sync" + "sync/atomic" "testing" "github.com/cloudwego/gopkg/protocol/thrift" @@ -228,3 +230,11 @@ func TestTransportException(t *testing.T) { test.Assert(t, err != nil, err) t.Logf("client stream send msg: %v %v", err, errors.Is(err, terrors.ErrIllegalFrame)) } + +func TestStreamID(t *testing.T) { + atomic.StoreInt32(&clientStreamID, math.MaxInt32-1) + id := genStreamID() + test.Assert(t, id == math.MaxInt32) + id = genStreamID() + test.Assert(t, id == math.MinInt32) +} diff --git a/pkg/streamx/streamx_gen_service_test.go b/pkg/streamx/streamx_gen_service_test.go index dbee613254..ccc3eece72 100644 --- a/pkg/streamx/streamx_gen_service_test.go +++ b/pkg/streamx/streamx_gen_service_test.go @@ -145,7 +145,7 @@ var streamingServiceInfo = &serviceinfo.ServiceInfo{ } // --- Define RegisterService interface --- -func RegisterService(svr server.Server, handler StreamingServerInterface, opts ...server.RegisterOption) error { +func RegisterStreamingService(svr server.Server, handler StreamingServerInterface, opts ...server.RegisterOption) error { return svr.RegisterService(streamingServiceInfo, handler, opts...) } From 3afbd88f25233d06ee3d4f5040f0aa3138ec3586 Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Wed, 6 Nov 2024 16:09:56 +0800 Subject: [PATCH 15/34] feat: TTHeader Streaming errors inherit from kerrors.ErrStreamingProtocol (#1603) --- pkg/kerrors/kerrors_test.go | 2 + pkg/kerrors/streaming_errors.go | 21 +++++++ .../ttstream/client_trans_pool_muxconn.go | 2 +- .../ttstream/client_trans_pool_shortconn.go | 2 +- .../provider/ttstream/errors/errors.go | 52 ---------------- pkg/streamx/provider/ttstream/frame.go | 2 +- pkg/streamx/provider/ttstream/stream.go | 2 +- .../provider/ttstream/terrors/terrors.go | 60 +++++++++++++++++++ .../terrors_test.go} | 18 +++++- .../provider/ttstream/transport_test.go | 2 +- pkg/streamx/streamx_common_test.go | 3 +- 11 files changed, 107 insertions(+), 59 deletions(-) create mode 100644 pkg/kerrors/streaming_errors.go delete mode 100644 pkg/streamx/provider/ttstream/errors/errors.go create mode 100644 pkg/streamx/provider/ttstream/terrors/terrors.go rename pkg/streamx/provider/ttstream/{errors/errors_test.go => terrors/terrors_test.go} (70%) diff --git a/pkg/kerrors/kerrors_test.go b/pkg/kerrors/kerrors_test.go index 1cd65cb257..f8505befb7 100644 --- a/pkg/kerrors/kerrors_test.go +++ b/pkg/kerrors/kerrors_test.go @@ -47,6 +47,8 @@ func TestIsKitexError(t *testing.T) { ErrNoMoreInstance, ErrConnOverLimit, ErrQPSOverLimit, + // streaming errors + ErrStreamingProtocol, } for _, e := range errs { test.Assert(t, IsKitexError(e)) diff --git a/pkg/kerrors/streaming_errors.go b/pkg/kerrors/streaming_errors.go new file mode 100644 index 0000000000..4233dd7fbe --- /dev/null +++ b/pkg/kerrors/streaming_errors.go @@ -0,0 +1,21 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kerrors + +// ErrStreamingProtocol is the parent type of all streaming protocol(e.g. gRPC, TTHeader Streaming) +// related but not user-aware errors. +var ErrStreamingProtocol = &basicError{"streaming protocol error"} diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go index 25082b3c7c..e372ad3931 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go @@ -28,7 +28,7 @@ import ( "golang.org/x/sync/singleflight" "github.com/cloudwego/kitex/pkg/serviceinfo" - terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) var DefaultMuxConnConfig = MuxConnConfig{ diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go index 71e15e94ed..d5bd7e463f 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go @@ -21,7 +21,7 @@ import ( "time" "github.com/cloudwego/kitex/pkg/serviceinfo" - terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) func newShortConnTransPool() transPool { diff --git a/pkg/streamx/provider/ttstream/errors/errors.go b/pkg/streamx/provider/ttstream/errors/errors.go deleted file mode 100644 index 7d75e331ad..0000000000 --- a/pkg/streamx/provider/ttstream/errors/errors.go +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package errors - -import ( - "errors" -) - -var ( - ErrUnexpectedHeader = &errType{message: "unexpected header frame"} - ErrUnexpectedTrailer = &errType{message: "unexpected trailer frame"} - ErrApplicationException = &errType{message: "application exception"} - ErrIllegalBizErr = &errType{message: "illegal bizErr"} - ErrIllegalFrame = &errType{message: "illegal frame"} - ErrIllegalOperation = &errType{message: "illegal operation"} - ErrTransport = &errType{message: "transport is closing"} -) - -type errType struct { - message string - basic error - cause error -} - -func (e *errType) WithCause(err error) error { - return &errType{message: e.message, basic: e, cause: err} -} - -func (e *errType) Error() string { - if e.cause == nil { - return e.message - } - return "[" + e.message + "] " + e.cause.Error() -} - -func (e *errType) Is(target error) bool { - return target == e || target == e.basic || errors.Is(e.cause, target) -} diff --git a/pkg/streamx/provider/ttstream/frame.go b/pkg/streamx/provider/ttstream/frame.go index 54169c0bd0..a0f5fdcfdc 100644 --- a/pkg/streamx/provider/ttstream/frame.go +++ b/pkg/streamx/provider/ttstream/frame.go @@ -32,7 +32,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/streamx" - terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) const ( diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index 0f4e9d7755..7e0cca2c03 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -32,7 +32,7 @@ import ( "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx" - terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" "github.com/cloudwego/kitex/pkg/transmeta" ) diff --git a/pkg/streamx/provider/ttstream/terrors/terrors.go b/pkg/streamx/provider/ttstream/terrors/terrors.go new file mode 100644 index 0000000000..769ffc2287 --- /dev/null +++ b/pkg/streamx/provider/ttstream/terrors/terrors.go @@ -0,0 +1,60 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package terrors + +import ( + "errors" + + "github.com/cloudwego/kitex/pkg/kerrors" +) + +// terrors define TTHeader Streaming-related protocol errors, they all inherit from ErrStreamingProtocol in kerrors. +var ( + ErrUnexpectedHeader = newErrType("unexpected header frame") + ErrApplicationException = newErrType("application exception") + ErrIllegalBizErr = newErrType("illegal bizErr") + ErrIllegalFrame = newErrType("illegal frame") + ErrIllegalOperation = newErrType("illegal operation") + ErrTransport = newErrType("transport is closing") +) + +type errType struct { + message string + // parent errType + basic error + // detailed err + cause error +} + +func newErrType(message string) *errType { + return &errType{message: message, basic: kerrors.ErrStreamingProtocol} +} + +func (e *errType) WithCause(err error) error { + return &errType{basic: e, cause: err} +} + +func (e *errType) Error() string { + if e.cause == nil { + return e.message + } + return "[" + e.basic.Error() + "] " + e.cause.Error() +} + +func (e *errType) Is(target error) bool { + return target == e || errors.Is(e.basic, target) || errors.Is(e.cause, target) +} diff --git a/pkg/streamx/provider/ttstream/errors/errors_test.go b/pkg/streamx/provider/ttstream/terrors/terrors_test.go similarity index 70% rename from pkg/streamx/provider/ttstream/errors/errors_test.go rename to pkg/streamx/provider/ttstream/terrors/terrors_test.go index d926d3786d..191e038a32 100644 --- a/pkg/streamx/provider/ttstream/errors/errors_test.go +++ b/pkg/streamx/provider/ttstream/terrors/terrors_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package errors +package terrors import ( "errors" @@ -23,12 +23,28 @@ import ( "testing" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" ) func TestErrors(t *testing.T) { causeErr := fmt.Errorf("test1") newErr := ErrIllegalFrame.WithCause(causeErr) test.Assert(t, errors.Is(newErr, ErrIllegalFrame), newErr) + test.Assert(t, errors.Is(newErr, kerrors.ErrStreamingProtocol), newErr) test.Assert(t, strings.Contains(newErr.Error(), ErrIllegalFrame.Error())) test.Assert(t, strings.Contains(newErr.Error(), causeErr.Error())) } + +func TestCommonParentKerror(t *testing.T) { + errs := []error{ + ErrUnexpectedHeader, + ErrApplicationException, + ErrIllegalBizErr, + ErrIllegalFrame, + ErrIllegalOperation, + ErrTransport, + } + for _, err := range errs { + test.Assert(t, errors.Is(err, kerrors.ErrStreamingProtocol), err) + } +} diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go index 68d4e5ecf0..50296ee58f 100644 --- a/pkg/streamx/provider/ttstream/transport_test.go +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -35,7 +35,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" - terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" "github.com/cloudwego/kitex/server/streamxserver" ) diff --git a/pkg/streamx/streamx_common_test.go b/pkg/streamx/streamx_common_test.go index 48907feea1..b1ff6af8d9 100644 --- a/pkg/streamx/streamx_common_test.go +++ b/pkg/streamx/streamx_common_test.go @@ -25,7 +25,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" - terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/errors" + terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) const ( @@ -81,6 +81,7 @@ func validateMetadata(ctx context.Context) bool { func assertNormalErr(t *testing.T, err error) { test.Assert(t, errors.Is(err, terrors.ErrApplicationException), err) + test.Assert(t, errors.Is(err, kerrors.ErrStreamingProtocol), err) } func assertBizErr(t *testing.T, err error) { From 8e476c504b0a93454b97b0748c3a5e1bfb3261d3 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Thu, 7 Nov 2024 19:26:41 +0800 Subject: [PATCH 16/34] refactor: using new user api --- client/client_streamx.go | 15 +- client/streamxclient/client_gen.go | 16 +- client/streamxclient/client_option.go | 27 +- go.mod | 2 +- go.sum | 2 + internal/server/option.go | 4 +- internal/server/register_option.go | 9 +- .../server/streamx_config.go | 20 +- pkg/remote/option.go | 6 + pkg/remote/remotecli/stream.go | 1 + pkg/remote/trans/streamx/server_handler.go | 14 +- pkg/serviceinfo/serviceinfo.go | 13 + .../ttstream/client_trans_pool_muxconn.go | 3 + pkg/streamx/provider/ttstream/mock_test.go | 1 + pkg/streamx/provider/ttstream/transport.go | 3 - .../provider/ttstream/transport_buffer.go | 6 +- .../ttstream/transport_buffer_test.go | 3 +- .../provider/ttstream/transport_test.go | 42 +- pkg/streamx/stream_args.go | 28 +- .../streamx/stream_handler.go | 26 +- pkg/streamx/stream_middleware.go | 11 - pkg/streamx/streamx_common_test.go | 2 +- pkg/streamx/streamx_gen_service_test.go | 142 +++--- pkg/streamx/streamx_user_service_test.go | 27 +- pkg/streamx/streamx_user_test.go | 441 ++++++++++-------- server/server.go | 37 +- server/service.go | 29 +- server/stream.go | 21 +- server/streamxserver/option.go | 48 ++ server/streamxserver/server_gen.go | 219 +++++---- server/streamxserver/server_option.go | 66 --- 31 files changed, 690 insertions(+), 594 deletions(-) rename server/streamxserver/server.go => internal/server/streamx_config.go (65%) rename client/streamxclient/client.go => pkg/streamx/stream_handler.go (51%) create mode 100644 server/streamxserver/option.go delete mode 100644 server/streamxserver/server_option.go diff --git a/client/client_streamx.go b/client/client_streamx.go index a0eb34829f..ac1369a847 100644 --- a/client/client_streamx.go +++ b/client/client_streamx.go @@ -25,16 +25,11 @@ import ( ) type StreamX interface { - NewStream(ctx context.Context, method string, req any, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) - Middlewares() (streamMW streamx.StreamMiddleware, recvMW streamx.StreamRecvMiddleware, sendMW streamx.StreamSendMiddleware) -} - -func (kc *kClient) Middlewares() (streamMW streamx.StreamMiddleware, recvMW streamx.StreamRecvMiddleware, sendMW streamx.StreamSendMiddleware) { - return kc.sxStreamMW, kc.sxStreamRecvMW, kc.sxStreamSendMW + NewStream(ctx context.Context, method string, req any, streamArgs streamx.StreamArgs, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) } // NewStream create stream for streamx mode -func (kc *kClient) NewStream(ctx context.Context, method string, req any, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) { +func (kc *kClient) NewStream(ctx context.Context, method string, req any, streamArgs streamx.StreamArgs, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) { if !kc.inited { panic("client not initialized") } @@ -61,7 +56,11 @@ func (kc *kClient) NewStream(ctx context.Context, method string, req any, callOp })) copts.Apply(callOptions) - streamArgs := streamx.NewStreamArgs(nil) + if msargs := streamx.AsMutableStreamArgs(streamArgs); msargs != nil { + msargs.SetStreamMiddleware(kc.sxStreamMW) + msargs.SetStreamRecvMiddleware(kc.sxStreamRecvMW) + msargs.SetStreamSendMiddleware(kc.sxStreamSendMW) + } // put streamArgs into response arg // it's an ugly trick but if we don't want to refactor too much, // this is the only way to compatible with current endpoint design diff --git a/client/streamxclient/client_gen.go b/client/streamxclient/client_gen.go index 6c1d4af974..a5ff65b64c 100644 --- a/client/streamxclient/client_gen.go +++ b/client/streamxclient/client_gen.go @@ -40,16 +40,20 @@ func InvokeStream[Req, Res any]( resArgs.SetRes(res) } - ctx, cs, err := cli.NewStream(ctx, method, req, callOptions...) + // NewStream should register client middlewares into stream Args + ctx, cs, err := cli.NewStream(ctx, method, req, streamArgs, callOptions...) if err != nil { return nil, nil, err } stream := streamx.NewGenericClientStream[Req, Res](cs) - streamx.AsMutableStreamArgs(streamArgs).SetStream(stream) - - streamMW, recvMW, sendMW := cli.Middlewares() - stream.SetStreamRecvMiddleware(recvMW) - stream.SetStreamSendMiddleware(sendMW) + var streamMW streamx.StreamMiddleware + if streamMWs, ok := streamArgs.(streamx.StreamMiddlewaresArgs); ok { + var recvMW streamx.StreamRecvMiddleware + var sendMW streamx.StreamSendMiddleware + streamMW, recvMW, sendMW = streamMWs.Middlewares() + stream.SetStreamRecvMiddleware(recvMW) + stream.SetStreamSendMiddleware(sendMW) + } streamInvoke := func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { // assemble streaming args depend on each stream mode diff --git a/client/streamxclient/client_option.go b/client/streamxclient/client_option.go index 0534b305a3..134f0a7483 100644 --- a/client/streamxclient/client_option.go +++ b/client/streamxclient/client_option.go @@ -19,31 +19,22 @@ package streamxclient import ( "time" - "github.com/cloudwego/kitex/client" internal_client "github.com/cloudwego/kitex/internal/client" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/utils" ) -type Option internal_client.Option +type Option = internal_client.Option -func WithHostPorts(hostports ...string) Option { - return ConvertNativeClientOption(client.WithHostPorts(hostports...)) -} - -func WithRecvTimeout(timeout time.Duration) Option { +func WithProvider(pvd streamx.ClientProvider) Option { return Option{F: func(o *internal_client.Options, di *utils.Slice) { - o.StreamXOptions.RecvTimeout = timeout + o.RemoteOpt.Provider = pvd }} } -func WithDestService(destService string) Option { - return ConvertNativeClientOption(client.WithDestService(destService)) -} - -func WithProvider(pvd streamx.ClientProvider) Option { +func WithStreamRecvTimeout(timeout time.Duration) Option { return Option{F: func(o *internal_client.Options, di *utils.Slice) { - o.RemoteOpt.Provider = pvd + o.StreamXOptions.RecvTimeout = timeout }} } @@ -64,11 +55,3 @@ func WithStreamSendMiddleware(smw streamx.StreamSendMiddleware) Option { o.StreamXOptions.StreamSendMWs = append(o.StreamXOptions.StreamSendMWs, smw) }} } - -func ConvertNativeClientOption(o internal_client.Option) Option { - return Option{F: o.F} -} - -func ConvertStreamXClientOption(o Option) internal_client.Option { - return internal_client.Option{F: o.F} -} diff --git a/go.mod b/go.mod index bdb558128e..030a562a8e 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/cloudwego/dynamicgo v0.4.6-0.20241115162834-0e99bc39b128 github.com/cloudwego/fastpb v0.0.5 github.com/cloudwego/frugal v0.2.0 - github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4 + github.com/cloudwego/gopkg v0.1.2-0.20240919030844-cb7123236682 github.com/cloudwego/localsession v0.1.1 github.com/cloudwego/netpoll v0.6.5-0.20240911104114-8a1f5597a920 github.com/cloudwego/runtimex v0.1.0 diff --git a/go.sum b/go.sum index 77d85246ad..81e9cf5cb6 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4 h1:SHw9GUBBcAnLWeK2MtPH7O6YQG9Q2ZZ8koD/4alpLvE= github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= +github.com/cloudwego/gopkg v0.1.2-0.20240919030844-cb7123236682 h1:hj/AhlEngERp5Tjt864veEvyK6RglXKcXpxkIOSRfug= +github.com/cloudwego/gopkg v0.1.2-0.20240919030844-cb7123236682/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/localsession v0.1.1 h1:tbK7laDVrYfFDXoBXo4uCGMAxU4qmz2dDm8d4BGBnDo= diff --git a/internal/server/option.go b/internal/server/option.go index 5c2a8ffb46..7bf3f21b7f 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -96,7 +96,9 @@ type Options struct { BackupOpt backup.Options - Streaming stream.StreamingConfig + // Streaming + Streaming stream.StreamingConfig // old version streaming API config + StreamX StreamXConfig // new version streaming API config RefuseTrafficWithoutServiceName bool EnableContextTimeout bool diff --git a/internal/server/register_option.go b/internal/server/register_option.go index 6cf884ac70..5d3c7579fe 100644 --- a/internal/server/register_option.go +++ b/internal/server/register_option.go @@ -18,7 +18,6 @@ package server import ( "github.com/cloudwego/kitex/pkg/endpoint" - "github.com/cloudwego/kitex/pkg/streamx" ) // RegisterOption is the only way to config service registration. @@ -28,12 +27,8 @@ type RegisterOption struct { // RegisterOptions is used to config service registration. type RegisterOptions struct { - IsFallbackService bool - Middlewares []endpoint.Middleware - StreamMiddlewares []streamx.StreamMiddleware - StreamRecvMiddlewares []streamx.StreamRecvMiddleware - StreamSendMiddlewares []streamx.StreamSendMiddleware - Provider streamx.ServerProvider + IsFallbackService bool + Middlewares []endpoint.Middleware } // NewRegisterOptions creates a register options. diff --git a/server/streamxserver/server.go b/internal/server/streamx_config.go similarity index 65% rename from server/streamxserver/server.go rename to internal/server/streamx_config.go index adfeb73432..9890b9cdb4 100644 --- a/server/streamxserver/server.go +++ b/internal/server/streamx_config.go @@ -14,19 +14,13 @@ * limitations under the License. */ -package streamxserver +package server -import ( - "github.com/cloudwego/kitex/server" -) +import "github.com/cloudwego/kitex/pkg/streamx" -type Server = server.Server - -func NewServer(opts ...Option) server.Server { - iopts := make([]server.Option, 0, len(opts)+1) - for _, opt := range opts { - iopts = append(iopts, ConvertStreamXServerOption(opt)) - } - s := server.NewServer(iopts...) - return s +type StreamXConfig struct { + StreamMiddlewares []streamx.StreamMiddleware + StreamRecvMiddlewares []streamx.StreamRecvMiddleware + StreamSendMiddlewares []streamx.StreamSendMiddleware + Provider streamx.ServerProvider } diff --git a/pkg/remote/option.go b/pkg/remote/option.go index 93a7ec416d..573b17eb8a 100644 --- a/pkg/remote/option.go +++ b/pkg/remote/option.go @@ -27,6 +27,7 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/pkg/streamx" ) // Option is used to pack the inbound and outbound handlers. @@ -122,6 +123,11 @@ type ServerOption struct { // for thrift streaming, this is enabled by default // for grpc(protobuf) streaming, it's disabled by default, enable with server.WithCompatibleMiddlewareForUnary CompatibleMiddlewareForUnary bool + + // for streamx middlewares + StreamMiddleware streamx.StreamMiddleware + StreamRecvMiddleware streamx.StreamRecvMiddleware + StreamSendMiddleware streamx.StreamSendMiddleware } // ClientOption is used to init the remote client. diff --git a/pkg/remote/remotecli/stream.go b/pkg/remote/remotecli/stream.go index 6bf541329d..3ecc05c0aa 100644 --- a/pkg/remote/remotecli/stream.go +++ b/pkg/remote/remotecli/stream.go @@ -36,6 +36,7 @@ func NewStream(ctx context.Context, ri rpcinfo.RPCInfo, handler remote.ClientTra } } + // streamx provider clientProvider, ok := opt.Provider.(streamx.ClientProvider) if ok { // wrap client provider diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index b4408bad43..5681365690 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -33,14 +33,14 @@ import ( "github.com/cloudwego/kitex/pkg/streamx" ) -/* 实际上 remote.ServerTransHandler 真正被 trans_server.go 使用的接口只有: +/* trans_server.go only use the following interface in remote.ServerTransHandler: - OnRead - OnActive - OnInactive - OnError - GracefulShutdown: assert 方式使用 -其他接口实际上最终是用来去组装了 transpipeline .... +Other interface is used by trans pipeline */ var streamWorkerPool = wpool.New(128, time.Second) @@ -134,11 +134,13 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) // - process server stream // - close server stream func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss streamx.ServerStream) (err error) { - // inkHdlFunc 包含了所有中间件 + 用户 serviceInfo.methodHandler - // 这里 streamx 依然会复用原本的 server endpoint.Endpoint 中间件,因为他们都不会单独去取 req/res 的值 - // 无法在保留现有 streaming 功能的情况下,彻底弃用 endpoint.Endpoint , 所以这里依然使用 endpoint 接口 - // 但是对用户 API ,做了单独的封装。把这部分脏逻辑仅暴露在框架中。 + // inkHdlFunc includes all middlewares and serviceInfo.methodHandler + // streamx still reuse the server endpoint.Endpoint and all server level middlewares sargs := streamx.NewStreamArgs(ss) + msargs := streamx.AsMutableStreamArgs(sargs) + msargs.SetStreamRecvMiddleware(t.opt.StreamRecvMiddleware) + msargs.SetStreamSendMiddleware(t.opt.StreamSendMiddleware) + msargs.SetStreamMiddleware(t.opt.StreamMiddleware) ctx = streamx.WithStreamArgsContext(ctx, sargs) ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr()) diff --git a/pkg/serviceinfo/serviceinfo.go b/pkg/serviceinfo/serviceinfo.go index 592c492d12..76f2dff2fc 100644 --- a/pkg/serviceinfo/serviceinfo.go +++ b/pkg/serviceinfo/serviceinfo.go @@ -117,6 +117,7 @@ type MethodInfo interface { OneWay() bool IsStreaming() bool StreamingMode() StreamingMode + Extra() map[string]string } // MethodHandler is corresponding to the handler wrapper func that in generated code @@ -133,6 +134,12 @@ func WithStreamingMode(mode StreamingMode) MethodInfoOption { } } +func WithMethodExtra(k, v string) MethodInfoOption { + return func(m *methodInfo) { + m.extra[k] = v + } +} + // NewMethodInfo is called in generated code to build method info func NewMethodInfo(methodHandler MethodHandler, newArgsFunc, newResultFunc func() interface{}, oneWay bool, opts ...MethodInfoOption) MethodInfo { mi := methodInfo{ @@ -142,6 +149,7 @@ func NewMethodInfo(methodHandler MethodHandler, newArgsFunc, newResultFunc func( oneWay: oneWay, isStreaming: false, streamingMode: StreamingNone, + extra: make(map[string]string), } for _, opt := range opts { opt(&mi) @@ -156,6 +164,7 @@ type methodInfo struct { oneWay bool isStreaming bool streamingMode StreamingMode + extra map[string]string } // Handler implements the MethodInfo interface. @@ -186,6 +195,10 @@ func (m methodInfo) StreamingMode() StreamingMode { return m.streamingMode } +func (m methodInfo) Extra() map[string]string { + return m.extra +} + // String prints human-readable information. func (p PayloadCodec) String() string { switch p { diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go index e372ad3931..ef44e3a07b 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go @@ -66,6 +66,9 @@ func newMuxConnTransList(size int, pool transPool) *muxConnTransList { func (tl *muxConnTransList) Close() { tl.L.Lock() for i, t := range tl.transports { + if t == nil { + continue + } _ = t.Close(nil) tl.transports[i] = nil } diff --git a/pkg/streamx/provider/ttstream/mock_test.go b/pkg/streamx/provider/ttstream/mock_test.go index f4e367b4c3..27701cf4b2 100644 --- a/pkg/streamx/provider/ttstream/mock_test.go +++ b/pkg/streamx/provider/ttstream/mock_test.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/cloudwego/frugal" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" kutils "github.com/cloudwego/kitex/pkg/utils" ) diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index cb59816960..c329629db4 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -244,9 +244,6 @@ func (t *transport) loopWrite() error { if err = EncodeFrame(context.Background(), writer, fr); err != nil { return err } - if err = t.conn.Writer().Flush(); err != nil { - return err - } recycleFrame(fr) } if err = writer.Flush(); err != nil { diff --git a/pkg/streamx/provider/ttstream/transport_buffer.go b/pkg/streamx/provider/ttstream/transport_buffer.go index 312eb9ccea..e697fc4891 100644 --- a/pkg/streamx/provider/ttstream/transport_buffer.go +++ b/pkg/streamx/provider/ttstream/transport_buffer.go @@ -118,9 +118,8 @@ func (c *writerBuffer) WriteBinary(bs []byte) (n int, err error) { } func (c *writerBuffer) WriteDirect(b []byte, remainCap int) (err error) { - err = c.writer.WriteDirect(b, remainCap) c.writeSize += len(b) - return err + return c.writer.WriteDirect(b, remainCap) } func (c *writerBuffer) WrittenLen() (length int) { @@ -128,7 +127,6 @@ func (c *writerBuffer) WrittenLen() (length int) { } func (c *writerBuffer) Flush() (err error) { - err = c.writer.Flush() c.writeSize = 0 - return err + return c.writer.Flush() } diff --git a/pkg/streamx/provider/ttstream/transport_buffer_test.go b/pkg/streamx/provider/ttstream/transport_buffer_test.go index 62f5d0c48b..74251a9200 100644 --- a/pkg/streamx/provider/ttstream/transport_buffer_test.go +++ b/pkg/streamx/provider/ttstream/transport_buffer_test.go @@ -21,8 +21,9 @@ package ttstream import ( "testing" - "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/internal/test" ) func TestTransportBuffer(t *testing.T) { diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go index 50296ee58f..a994d0a99c 100644 --- a/pkg/streamx/provider/ttstream/transport_test.go +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -27,6 +27,7 @@ import ( "sync" "sync/atomic" "testing" + "time" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/netpoll" @@ -44,8 +45,12 @@ var testServiceInfo = &serviceinfo.ServiceInfo{ Methods: map[string]serviceinfo.MethodInfo{ "Bidi": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[testRequest, testResponse]( - ctx, serviceinfo.StreamingBidirectional, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + return streamxserver.InvokeBidiStreamHandler[testRequest, testResponse]( + ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs), + func(ctx context.Context, stream streamx.BidiStreamingServer[testRequest, testResponse]) error { + return nil + }, + ) }, nil, nil, @@ -53,7 +58,6 @@ var testServiceInfo = &serviceinfo.ServiceInfo{ serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), ), }, - Extra: map[string]interface{}{"streaming": true, "streamx": true}, } func TestTransportBasic(t *testing.T) { @@ -198,24 +202,36 @@ func TestTransportException(t *testing.T) { sconn, err := netpoll.NewFDConnection(sfd) test.Assert(t, err == nil, err) + // server send data ctrans := newTransport(clientTransport, testServiceInfo, cconn, nil) rawClientStream, err := ctrans.WriteStream(context.Background(), "Bidi", make(IntHeader), make(streamx.Header)) test.Assert(t, err == nil, err) strans := newTransport(serverTransport, testServiceInfo, sconn, nil) rawServerStream, err := strans.ReadStream(context.Background()) test.Assert(t, err == nil, err) + cStream := newClientStream(rawClientStream) + sStream := newServerStream(rawServerStream) + res := new(testResponse) + res.A = 123 + err = sStream.SendMsg(context.Background(), res) + test.Assert(t, err == nil, err) + res = new(testResponse) + err = cStream.RecvMsg(context.Background(), res) + test.Assert(t, err == nil, err) + test.Assert(t, res.A == 123, res) // server send exception - ss := newServerStream(rawServerStream) targetException := thrift.NewApplicationException(remote.InternalError, "test") - err = ss.CloseSend(targetException) + err = sStream.CloseSend(targetException) test.Assert(t, err == nil, err) // client recv exception - cs := newClientStream(rawClientStream) - res := new(testResponse) - err = cs.RecvMsg(context.Background(), res) + res = new(testResponse) + err = cStream.RecvMsg(context.Background(), res) test.Assert(t, err != nil, err) test.Assert(t, strings.Contains(err.Error(), targetException.Msg()), err.Error()) + err = cStream.CloseSend(context.Background()) + test.Assert(t, err == nil, err) + time.Sleep(time.Millisecond * 50) // server send illegal frame rawClientStream, err = ctrans.WriteStream(context.Background(), "Bidi", make(IntHeader), make(streamx.Header)) @@ -223,12 +239,14 @@ func TestTransportException(t *testing.T) { rawServerStream, err = strans.ReadStream(context.Background()) test.Assert(t, err == nil, err) test.Assert(t, rawServerStream != nil, rawServerStream) - _, err = sconn.Write([]byte("helloxxxxxxxxxxxxxxxxxxxxxx")) + _, err = sconn.Writer().WriteBinary([]byte("helloxxxxxxxxxxxxxxxxxxxxxx")) + test.Assert(t, err == nil, err) + err = sconn.Writer().Flush() test.Assert(t, err == nil, err) - cs = newClientStream(rawClientStream) - err = cs.RecvMsg(context.Background(), res) + cStream = newClientStream(rawClientStream) + err = cStream.RecvMsg(context.Background(), res) test.Assert(t, err != nil, err) - t.Logf("client stream send msg: %v %v", err, errors.Is(err, terrors.ErrIllegalFrame)) + test.Assert(t, errors.Is(err, terrors.ErrIllegalFrame), err) } func TestStreamID(t *testing.T) { diff --git a/pkg/streamx/stream_args.go b/pkg/streamx/stream_args.go index 432ba26d58..9592edfd2c 100644 --- a/pkg/streamx/stream_args.go +++ b/pkg/streamx/stream_args.go @@ -41,6 +41,10 @@ type StreamArgs interface { Stream() Stream } +type StreamMiddlewaresArgs interface { + Middlewares() (StreamMiddleware, StreamRecvMiddleware, StreamSendMiddleware) +} + func AsStream(args interface{}) (Stream, error) { s, ok := args.(StreamArgs) if !ok { @@ -51,6 +55,9 @@ func AsStream(args interface{}) (Stream, error) { type MutableStreamArgs interface { SetStream(st Stream) + SetStreamMiddleware(mw StreamMiddleware) + SetStreamRecvMiddleware(mw StreamRecvMiddleware) + SetStreamSendMiddleware(mw StreamSendMiddleware) } func AsMutableStreamArgs(args StreamArgs) MutableStreamArgs { @@ -62,17 +69,36 @@ func AsMutableStreamArgs(args StreamArgs) MutableStreamArgs { } type streamArgs struct { - stream Stream + stream Stream + streamMW StreamMiddleware + recvMW StreamRecvMiddleware + sendMW StreamSendMiddleware } func (s *streamArgs) Stream() Stream { return s.stream } +func (s *streamArgs) Middlewares() (StreamMiddleware, StreamRecvMiddleware, StreamSendMiddleware) { + return s.streamMW, s.recvMW, s.sendMW +} + func (s *streamArgs) SetStream(st Stream) { s.stream = st } +func (s *streamArgs) SetStreamMiddleware(mw StreamMiddleware) { + s.streamMW = mw +} + +func (s *streamArgs) SetStreamRecvMiddleware(mw StreamRecvMiddleware) { + s.recvMW = mw +} + +func (s *streamArgs) SetStreamSendMiddleware(mw StreamSendMiddleware) { + s.sendMW = mw +} + func NewStreamArgs(stream Stream) StreamArgs { return &streamArgs{stream: stream} } diff --git a/client/streamxclient/client.go b/pkg/streamx/stream_handler.go similarity index 51% rename from client/streamxclient/client.go rename to pkg/streamx/stream_handler.go index 09f52707b3..c52549069b 100644 --- a/client/streamxclient/client.go +++ b/pkg/streamx/stream_handler.go @@ -14,25 +14,15 @@ * limitations under the License. */ -package streamxclient +package streamx import ( - "github.com/cloudwego/kitex/client" - iclient "github.com/cloudwego/kitex/internal/client" - "github.com/cloudwego/kitex/pkg/serviceinfo" + "context" ) -type Client = client.StreamX - -func NewClient(svcInfo *serviceinfo.ServiceInfo, opts ...Option) (Client, error) { - iopts := make([]client.Option, 0, len(opts)+1) - for _, opt := range opts { - iopts = append(iopts, ConvertStreamXClientOption(opt)) - } - nopts := iclient.NewOptions(iopts) - c, err := client.NewClientWithOptions(svcInfo, nopts) - if err != nil { - return nil, err - } - return c.(client.StreamX), nil -} +type ( + UnaryHandler[Req, Res any] func(ctx context.Context, req *Req) (*Res, error) + ClientStreamingHandler[Req, Res any] func(ctx context.Context, stream ClientStreamingServer[Req, Res]) (*Res, error) + ServerStreamingHandler[Req, Res any] func(ctx context.Context, req *Req, stream ServerStreamingServer[Res]) error + BidiStreamingHandler[Req, Res any] func(ctx context.Context, stream BidiStreamingServer[Req, Res]) error +) diff --git a/pkg/streamx/stream_middleware.go b/pkg/streamx/stream_middleware.go index bce62d65a7..078df1fbf0 100644 --- a/pkg/streamx/stream_middleware.go +++ b/pkg/streamx/stream_middleware.go @@ -20,24 +20,13 @@ import ( "context" ) -type StreamHandler struct { - Handler any - StreamMiddleware StreamMiddleware - StreamRecvMiddleware StreamRecvMiddleware - StreamSendMiddleware StreamSendMiddleware -} - type ( StreamEndpoint func(ctx context.Context, streamArgs StreamArgs, reqArgs StreamReqArgs, resArgs StreamResArgs) (err error) StreamMiddleware func(next StreamEndpoint) StreamEndpoint -) -type ( StreamRecvEndpoint func(ctx context.Context, stream Stream, res any) (err error) StreamSendEndpoint func(ctx context.Context, stream Stream, req any) (err error) -) -type ( StreamRecvMiddleware func(next StreamRecvEndpoint) StreamRecvEndpoint StreamSendMiddleware func(next StreamSendEndpoint) StreamSendEndpoint ) diff --git a/pkg/streamx/streamx_common_test.go b/pkg/streamx/streamx_common_test.go index b1ff6af8d9..30091a9cc9 100644 --- a/pkg/streamx/streamx_common_test.go +++ b/pkg/streamx/streamx_common_test.go @@ -25,7 +25,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" - terrors "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) const ( diff --git a/pkg/streamx/streamx_gen_service_test.go b/pkg/streamx/streamx_gen_service_test.go index ccc3eece72..c267d6afd4 100644 --- a/pkg/streamx/streamx_gen_service_test.go +++ b/pkg/streamx/streamx_gen_service_test.go @@ -24,7 +24,6 @@ import ( "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/streamxserver" ) @@ -32,16 +31,16 @@ import ( // === gen code === // --- Define Service Method handler --- -var pingpongServiceInfo = &serviceinfo.ServiceInfo{ - ServiceName: "kitex.service.pingpong", +var testServiceInfo = &serviceinfo.ServiceInfo{ + ServiceName: "kitex.echo.service", PayloadCodec: serviceinfo.Thrift, - HandlerType: (*PingPongServerInterface)(nil), + HandlerType: (*TestService)(nil), Methods: map[string]serviceinfo.MethodInfo{ "PingPong": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { realArg := reqArgs.(*ServerPingPongArgs) realResult := resArgs.(*ServerPingPongResult) - success, err := handler.(PingPongServerInterface).PingPong(ctx, realArg.Req) + success, err := handler.(TestService).PingPong(ctx, realArg.Req) if err != nil { return err } @@ -53,137 +52,157 @@ var pingpongServiceInfo = &serviceinfo.ServiceInfo{ false, serviceinfo.WithStreamingMode(serviceinfo.StreamingNone), ), - }, - Extra: map[string]interface{}{"streaming": false}, -} - -var streamingServiceInfo = &serviceinfo.ServiceInfo{ - ServiceName: "kitex.service.streaming", - Methods: map[string]serviceinfo.MethodInfo{ "Unary": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingUnary, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + return streamxserver.InvokeUnaryHandler[Request, Response]( + ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs), + func(ctx context.Context, req *Request) (*Response, error) { + return handler.(TestService).Unary(ctx, req) + }, + ) }, nil, nil, false, serviceinfo.WithStreamingMode(serviceinfo.StreamingUnary), + serviceinfo.WithMethodExtra("streamx", "true"), ), "ClientStream": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingClient, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + return streamxserver.InvokeClientStreamHandler[Request, Response]( + ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs), + func(ctx context.Context, stream streamx.ClientStreamingServer[Request, Response]) (*Response, error) { + return handler.(TestService).ClientStream(ctx, stream) + }, + ) }, nil, nil, false, serviceinfo.WithStreamingMode(serviceinfo.StreamingClient), + serviceinfo.WithMethodExtra("streamx", "true"), ), "ServerStream": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingServer, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + return streamxserver.InvokeServerStreamHandler( + ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs), + func(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error { + return handler.(TestService).ServerStream(ctx, req, stream) + }, + ) }, nil, nil, false, serviceinfo.WithStreamingMode(serviceinfo.StreamingServer), + serviceinfo.WithMethodExtra("streamx", "true"), ), "BidiStream": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingBidirectional, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + return streamxserver.InvokeBidiStreamHandler[Request, Response]( + ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs), + func(ctx context.Context, stream streamx.BidiStreamingServer[Request, Response]) error { + return handler.(TestService).BidiStream(ctx, stream) + }, + ) }, nil, nil, false, serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + serviceinfo.WithMethodExtra("streamx", "true"), ), "UnaryWithErr": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingUnary, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + return streamxserver.InvokeUnaryHandler[Request, Response]( + ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs), + func(ctx context.Context, req *Request) (*Response, error) { + return handler.(TestService).UnaryWithErr(ctx, req) + }, + ) }, nil, nil, false, serviceinfo.WithStreamingMode(serviceinfo.StreamingUnary), + serviceinfo.WithMethodExtra("streamx", "true"), ), "ClientStreamWithErr": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingClient, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + return streamxserver.InvokeClientStreamHandler[Request, Response]( + ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs), + func(ctx context.Context, stream streamx.ClientStreamingServer[Request, Response]) (*Response, error) { + return handler.(TestService).ClientStreamWithErr(ctx, stream) + }, + ) }, nil, nil, false, serviceinfo.WithStreamingMode(serviceinfo.StreamingClient), + serviceinfo.WithMethodExtra("streamx", "true"), ), "ServerStreamWithErr": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingServer, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + return streamxserver.InvokeServerStreamHandler[Request, Response]( + ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs), + func(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error { + return handler.(TestService).ServerStreamWithErr(ctx, req, stream) + }, + ) }, nil, nil, false, serviceinfo.WithStreamingMode(serviceinfo.StreamingServer), + serviceinfo.WithMethodExtra("streamx", "true"), ), "BidiStreamWithErr": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[Request, Response]( - ctx, serviceinfo.StreamingBidirectional, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) + return streamxserver.InvokeBidiStreamHandler[Request, Response]( + ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs), + func(ctx context.Context, stream streamx.BidiStreamingServer[Request, Response]) error { + return handler.(TestService).BidiStreamWithErr(ctx, stream) + }, + ) }, nil, nil, false, serviceinfo.WithStreamingMode(serviceinfo.StreamingBidirectional), + serviceinfo.WithMethodExtra("streamx", "true"), ), }, - Extra: map[string]interface{}{"streaming": true, "streamx": true}, -} - -// --- Define RegisterService interface --- -func RegisterStreamingService(svr server.Server, handler StreamingServerInterface, opts ...server.RegisterOption) error { - return svr.RegisterService(streamingServiceInfo, handler, opts...) } -// --- Define New Client interface --- -func NewPingPongClient(destService string, opts ...client.Option) (PingPongClientInterface, error) { - var options []client.Option - options = append(options, client.WithDestService(destService)) +// --- Define NewServer interface --- +func NewServer(handler TestService, opts ...server.Option) server.Server { + var options []server.Option options = append(options, opts...) - cli, err := client.NewClient(pingpongServiceInfo, options...) - if err != nil { - return nil, err + svr := server.NewServer(options...) + if err := svr.RegisterService(testServiceInfo, handler); err != nil { + panic(err) } - kc := &kClient{caller: cli} - return kc, nil + return svr } -func NewStreamingClient(destService string, opts ...streamxclient.Option) (StreamingClientInterface, error) { - var options []streamxclient.Option - options = append(options, streamxclient.WithDestService(destService)) - cp, err := ttstream.NewClientProvider(streamingServiceInfo) - if err != nil { - return nil, err - } - options = append(options, streamxclient.WithProvider(cp)) +// --- Define NewClient interface --- +func NewClient(destService string, opts ...client.Option) (TestServiceClient, error) { + var options []client.Option + options = append(options, client.WithDestService(destService)) options = append(options, opts...) - cli, err := streamxclient.NewClient(streamingServiceInfo, options...) + cli, err := client.NewClient(testServiceInfo, options...) if err != nil { return nil, err } - kc := &kClient{streamer: cli, caller: cli.(client.Client)} + kc := &kClient{caller: cli, streamer: cli.(client.StreamX)} return kc, nil } // --- Define Server Implementation Interface --- -type PingPongServerInterface interface { +type TestService interface { PingPong(ctx context.Context, req *Request) (*Response, error) -} -type StreamingServerInterface interface { + Unary(ctx context.Context, req *Request) (*Response, error) ClientStream(ctx context.Context, stream streamx.ClientStreamingServer[Request, Response]) (*Response, error) ServerStream(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error @@ -195,11 +214,9 @@ type StreamingServerInterface interface { } // --- Define Client Implementation Interface --- -type PingPongClientInterface interface { +type TestServiceClient interface { PingPong(ctx context.Context, req *Request) (r *Response, err error) -} -type StreamingClientInterface interface { Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( context.Context, streamx.ClientStreamingClient[Request, Response], error) @@ -217,14 +234,11 @@ type StreamingClientInterface interface { } // --- Define Client Implementation --- -var ( - _ StreamingClientInterface = (*kClient)(nil) - _ PingPongClientInterface = (*kClient)(nil) -) +var _ TestServiceClient = (*kClient)(nil) type kClient struct { caller client.Client - streamer streamxclient.Client + streamer client.StreamX } func (c *kClient) PingPong(ctx context.Context, req *Request) (r *Response, err error) { diff --git a/pkg/streamx/streamx_user_service_test.go b/pkg/streamx/streamx_user_service_test.go index f701a31dc6..c48dc5f212 100644 --- a/pkg/streamx/streamx_user_service_test.go +++ b/pkg/streamx/streamx_user_service_test.go @@ -29,10 +29,7 @@ import ( "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/ktx" ) -type ( - pingpongService struct{} - streamingService struct{} -) +type testService struct{} const ( headerKey = "header1" @@ -50,7 +47,7 @@ func testHeaderAndTrailer(t *testing.T, stream streamx.ClientStreamMetadata) { test.Assert(t, tl[trailerKey] == trailerVal, tl) } -func (si *streamingService) setHeaderAndTrailer(stream streamx.ServerStreamMetadata) error { +func (si *testService) setHeaderAndTrailer(stream streamx.ServerStreamMetadata) error { err := stream.SetTrailer(streamx.Trailer{trailerKey: trailerVal}) if err != nil { return err @@ -63,19 +60,19 @@ func (si *streamingService) setHeaderAndTrailer(stream streamx.ServerStreamMetad return nil } -func (si *pingpongService) PingPong(ctx context.Context, req *Request) (*Response, error) { +func (si *testService) PingPong(ctx context.Context, req *Request) (*Response, error) { resp := &Response{Type: req.Type, Message: req.Message} klog.Infof("Server PingPong: req={%v} resp={%v}", req, resp) return resp, nil } -func (si *streamingService) Unary(ctx context.Context, req *Request) (*Response, error) { +func (si *testService) Unary(ctx context.Context, req *Request) (*Response, error) { resp := &Response{Type: req.Type, Message: req.Message} klog.Infof("Server Unary: req={%v} resp={%v}", req, resp) return resp, nil } -func (si *streamingService) ClientStream(ctx context.Context, +func (si *testService) ClientStream(ctx context.Context, stream streamx.ClientStreamingServer[Request, Response], ) (*Response, error) { var msg string @@ -100,7 +97,7 @@ func (si *streamingService) ClientStream(ctx context.Context, } } -func (si *streamingService) ServerStream(ctx context.Context, req *Request, +func (si *testService) ServerStream(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response], ) error { klog.Infof("Server ServerStream: req={%v}", req) @@ -109,7 +106,7 @@ func (si *streamingService) ServerStream(ctx context.Context, req *Request, return err } - for i := 0; i < 3; i++ { + for i := 0; i < 5; i++ { resp := new(Response) resp.Type = int32(i) resp.Message = req.Message @@ -122,7 +119,7 @@ func (si *streamingService) ServerStream(ctx context.Context, req *Request, return nil } -func (si *streamingService) BidiStream(ctx context.Context, +func (si *testService) BidiStream(ctx context.Context, stream streamx.BidiStreamingServer[Request, Response], ) error { ktx.RegisterCancelCallback(ctx, func() { @@ -166,13 +163,13 @@ func buildErr(req *Request) error { return err } -func (si *streamingService) UnaryWithErr(ctx context.Context, req *Request) (*Response, error) { +func (si *testService) UnaryWithErr(ctx context.Context, req *Request) (*Response, error) { err := buildErr(req) klog.Infof("Server UnaryWithErr: req={%v} err={%v}", req, err) return nil, err } -func (si *streamingService) ClientStreamWithErr(ctx context.Context, stream streamx.ClientStreamingServer[Request, Response]) (res *Response, err error) { +func (si *testService) ClientStreamWithErr(ctx context.Context, stream streamx.ClientStreamingServer[Request, Response]) (res *Response, err error) { req, err := stream.Recv(ctx) if err != nil { klog.Errorf("Server ClientStreamWithErr Recv failed, exception={%v}", err) @@ -183,13 +180,13 @@ func (si *streamingService) ClientStreamWithErr(ctx context.Context, stream stre return nil, err } -func (si *streamingService) ServerStreamWithErr(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error { +func (si *testService) ServerStreamWithErr(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error { err := buildErr(req) klog.Infof("Server ServerStreamWithErr: req={%v} err={%v}", req, err) return err } -func (si *streamingService) BidiStreamWithErr(ctx context.Context, stream streamx.BidiStreamingServer[Request, Response]) error { +func (si *testService) BidiStreamWithErr(ctx context.Context, stream streamx.BidiStreamingServer[Request, Response]) error { req, err := stream.Recv(ctx) if err != nil { klog.Errorf("Server BidiStreamWithErr Recv failed, exception={%v}", err) diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index 8f7e2d01f7..d9bf8a68d6 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -32,18 +32,19 @@ import ( "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/transport" + "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" - "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/streamxclient" "github.com/cloudwego/kitex/internal/test" - "github.com/cloudwego/kitex/pkg/remote/codec/thrift" - "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/streamxserver" - "github.com/cloudwego/kitex/transport" ) var providerTestCases []testCase @@ -55,12 +56,12 @@ type testCase struct { } func init() { - sp, _ := ttstream.NewServerProvider(streamingServiceInfo) - cp, _ := ttstream.NewClientProvider(streamingServiceInfo, ttstream.WithClientLongConnPool(ttstream.LongConnConfig{MaxIdleTimeout: time.Millisecond * 100})) + sp, _ := ttstream.NewServerProvider(testServiceInfo) + cp, _ := ttstream.NewClientProvider(testServiceInfo, ttstream.WithClientLongConnPool(ttstream.LongConnConfig{MaxIdleTimeout: time.Millisecond * 100})) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_LongConn", ClientProvider: cp, ServerProvider: sp}) - cp, _ = ttstream.NewClientProvider(streamingServiceInfo, ttstream.WithClientShortConnPool()) + cp, _ = ttstream.NewClientProvider(testServiceInfo, ttstream.WithClientShortConnPool()) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_ShortConn", ClientProvider: cp, ServerProvider: sp}) - cp, _ = ttstream.NewClientProvider(streamingServiceInfo, ttstream.WithClientMuxConnPool(ttstream.MuxConnConfig{PoolSize: 8, MaxIdleTimeout: time.Millisecond * 1000})) + cp, _ = ttstream.NewClientProvider(testServiceInfo, ttstream.WithClientMuxConnPool(ttstream.MuxConnConfig{PoolSize: 8, MaxIdleTimeout: time.Millisecond * 1000})) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_Mux", ClientProvider: cp, ServerProvider: sp}) } @@ -72,51 +73,84 @@ func TestMain(m *testing.M) { m.Run() } +func NewTestServer(serviceImpl TestService, opts ...server.Option) (string, server.Server, error) { + addr := test.GetLocalAddress() + ln, err := netpoll.CreateListener("tcp", addr) + if err != nil { + return "", nil, err + } + options := []server.Option{ + server.WithListener(ln), + server.WithExitWaitTime(time.Millisecond * 10), + } + options = append(options, opts...) + svr := NewServer(serviceImpl, options...) + go func() { + _ = svr.Run() + }() + test.WaitServerStart(addr) + return addr, svr, nil +} + +func untilEqual(t *testing.T, current *int32, target int32, timeout time.Duration) { + var duration time.Duration + interval := time.Millisecond * 10 + for atomic.LoadInt32(current) != target { + time.Sleep(interval) + duration += interval + if duration > timeout { + t.Fatalf("current(%d) != target(%d)", current, target) + return + } + } +} + +func increaseIfNoError(val *int32, err error) { + if err != nil { + return + } + atomic.AddInt32(val, 1) +} + func TestStreamingBasic(t *testing.T) { for _, tc := range providerTestCases { t.Run(tc.Name, func(t *testing.T) { - concurrency := 100 + concurrency := 10 round := 5 - // === prepare test environment === - addr := test.GetLocalAddress() - ln, err := netpoll.CreateListener("tcp", addr) - test.Assert(t, err == nil, err) - defer ln.Close() // create server + var serverMiddlewareCount int32 var serverStreamCount int32 - waitServerStreamDone := func() { - for atomic.LoadInt32(&serverStreamCount) != 0 { - t.Logf("waitServerStreamDone: %d", atomic.LoadInt32(&serverStreamCount)) - time.Sleep(time.Millisecond * 10) - } - } var serverRecvCount int32 var serverSendCount int32 - svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) - // register pingpong service - err = svr.RegisterService(pingpongServiceInfo, new(pingpongService)) - test.Assert(t, err == nil, err) - // register streamingService as ttstreaam provider - err = svr.RegisterService( - streamingServiceInfo, - new(streamingService), + resetServerCount := func() { + atomic.StoreInt32(&serverMiddlewareCount, 0) + atomic.StoreInt32(&serverStreamCount, 0) + atomic.StoreInt32(&serverRecvCount, 0) + atomic.StoreInt32(&serverSendCount, 0) + } + addr, svr, err := NewTestServer( + new(testService), + server.WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, req, resp interface{}) (err error) { + err = next(ctx, req, resp) + increaseIfNoError(&serverMiddlewareCount, err) + return err + } + }), + streamxserver.WithProvider(tc.ServerProvider), streamxserver.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { return func(ctx context.Context, stream streamx.Stream, res any) (err error) { err = next(ctx, stream, res) - if err == nil { - atomic.AddInt32(&serverRecvCount, 1) - } + increaseIfNoError(&serverRecvCount, err) return err } }), streamxserver.WithStreamSendMiddleware(func(next streamx.StreamSendEndpoint) streamx.StreamSendEndpoint { return func(ctx context.Context, stream streamx.Stream, req any) (err error) { err = next(ctx, stream, req) - if err == nil { - atomic.AddInt32(&serverSendCount, 1) - } + increaseIfNoError(&serverSendCount, err) return err } }), @@ -124,7 +158,7 @@ func TestStreamingBasic(t *testing.T) { // middleware example: server streaming mode func(next streamx.StreamEndpoint) streamx.StreamEndpoint { return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { - log.Printf("Server middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", + t.Logf("Server middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", reqArgs.Req(), resArgs.Res(), streamArgs) test.Assert(t, streamArgs.Stream() != nil) test.Assert(t, validateMetadata(ctx)) @@ -134,20 +168,29 @@ func TestStreamingBasic(t *testing.T) { test.Assert(t, reqArgs.Req() != nil) test.Assert(t, resArgs.Res() == nil) err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() != nil || err != nil) + if err == nil { + req := reqArgs.Req().(*Request) + res := resArgs.Res().(*Response) + test.DeepEqual(t, req.Message, res.Message) + } case streamx.StreamingClient: test.Assert(t, reqArgs.Req() == nil) test.Assert(t, resArgs.Res() == nil) err = next(ctx, streamArgs, reqArgs, resArgs) - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() != nil || err != nil) + if err == nil { + res := resArgs.Res().(*Response) + test.Assert(t, res.Message != "") + } case streamx.StreamingServer: test.Assert(t, reqArgs.Req() != nil) test.Assert(t, resArgs.Res() == nil) err = next(ctx, streamArgs, reqArgs, resArgs) test.Assert(t, reqArgs.Req() != nil) test.Assert(t, resArgs.Res() == nil) + if err == nil { + req := reqArgs.Req().(*Request) + test.Assert(t, req.Message != "") + } case streamx.StreamingBidirectional: test.Assert(t, reqArgs.Req() == nil) test.Assert(t, resArgs.Res() == nil) @@ -156,44 +199,54 @@ func TestStreamingBasic(t *testing.T) { test.Assert(t, resArgs.Res() == nil) } - log.Printf("Server middleware after next: reqArgs=%v resArgs=%v streamArgs=%v err=%v", + t.Logf("Server middleware after next: reqArgs=%v resArgs=%v streamArgs=%v err=%v", reqArgs.Req(), resArgs.Res(), streamArgs.Stream(), err) - atomic.AddInt32(&serverStreamCount, 1) + increaseIfNoError(&serverStreamCount, err) return err } }, ), ) test.Assert(t, err == nil, err) - go func() { - err := svr.Run() - test.Assert(t, err == nil, err) - }() defer svr.Stop() - test.WaitServerStart(addr) // create client - pingpongClient, err := NewPingPongClient( - "kitex.service.pingpong", + var clientMiddlewareCount int32 + var clientStreamCount int32 + var clientRecvCount int32 + var clientSendCount int32 + resetClientCount := func() { + atomic.StoreInt32(&clientMiddlewareCount, 0) + atomic.StoreInt32(&clientStreamCount, 0) + atomic.StoreInt32(&clientRecvCount, 0) + atomic.StoreInt32(&clientSendCount, 0) + } + cli, err := NewClient( + "kitex.echo.service", client.WithHostPorts(addr), client.WithTransportProtocol(transport.TTHeaderFramed), - client.WithPayloadCodec(thrift.NewThriftCodecWithConfig(thrift.FastRead|thrift.FastWrite|thrift.EnableSkipDecoder)), - ) - test.Assert(t, err == nil, err) - // create streaming client - streamClient, err := NewStreamingClient( - "kitex.service.streaming", + client.WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, req, resp interface{}) (err error) { + err = next(ctx, req, resp) + increaseIfNoError(&clientMiddlewareCount, err) + return err + } + }), + streamxclient.WithProvider(tc.ClientProvider), - streamxclient.WithHostPorts(addr), streamxclient.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { return func(ctx context.Context, stream streamx.Stream, res any) (err error) { err = next(ctx, stream, res) + if err == nil { + atomic.AddInt32(&clientRecvCount, 1) + } return err } }), streamxclient.WithStreamSendMiddleware(func(next streamx.StreamSendEndpoint) streamx.StreamSendEndpoint { return func(ctx context.Context, stream streamx.Stream, req any) (err error) { err = next(ctx, stream, req) + increaseIfNoError(&clientSendCount, err) return err } }), @@ -205,20 +258,26 @@ func TestStreamingBasic(t *testing.T) { err = next(ctx, streamArgs, reqArgs, resArgs) test.Assert(t, streamArgs.Stream() != nil) - switch streamArgs.Stream().Mode() { - case streamx.StreamingUnary: - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() != nil || err != nil) - case streamx.StreamingClient: - test.Assert(t, reqArgs.Req() == nil, reqArgs.Req()) - test.Assert(t, resArgs.Res() == nil) - case streamx.StreamingServer: - test.Assert(t, reqArgs.Req() != nil) - test.Assert(t, resArgs.Res() == nil) - case streamx.StreamingBidirectional: - test.Assert(t, reqArgs.Req() == nil) - test.Assert(t, resArgs.Res() == nil) + if err == nil { + switch streamArgs.Stream().Mode() { + case streamx.StreamingUnary: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() != nil) + req := reqArgs.Req().(*Request) + res := resArgs.Res().(*Response) + test.DeepEqual(t, req.Message, res.Message) + case streamx.StreamingClient: + test.Assert(t, reqArgs.Req() == nil, reqArgs.Req()) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingServer: + test.Assert(t, reqArgs.Req() != nil) + test.Assert(t, resArgs.Res() == nil) + case streamx.StreamingBidirectional: + test.Assert(t, reqArgs.Req() == nil) + test.Assert(t, resArgs.Res() == nil) + } } + increaseIfNoError(&clientStreamCount, err) return err } }), @@ -236,12 +295,22 @@ func TestStreamingBasic(t *testing.T) { defer wg.Done() req := new(Request) req.Message = "PingPong" - res, err := pingpongClient.PingPong(octx, req) + res, err := cli.PingPong(octx, req) test.Assert(t, err == nil, err) test.Assert(t, req.Message == res.Message, res) }() } wg.Wait() + test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientMiddlewareCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&serverStreamCount), int32(0)) + test.DeepEqual(t, atomic.LoadInt32(&clientStreamCount), int32(0)) + test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(0)) + test.DeepEqual(t, atomic.LoadInt32(&clientRecvCount), int32(0)) + test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(0)) + test.DeepEqual(t, atomic.LoadInt32(&clientSendCount), int32(0)) + resetServerCount() + resetClientCount() t.Logf("=== Unary ===") for i := 0; i < concurrency; i++ { @@ -251,19 +320,23 @@ func TestStreamingBasic(t *testing.T) { req := new(Request) req.Type = 10000 req.Message = "Unary" - res, err := streamClient.Unary(octx, req) + res, err := cli.Unary(octx, req) test.Assert(t, err == nil, err) test.Assert(t, req.Type == res.Type, res.Type) test.Assert(t, req.Message == res.Message, res.Message) - atomic.AddInt32(&serverStreamCount, -1) }() } wg.Wait() - waitServerStreamDone() + test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientMiddlewareCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&serverStreamCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientStreamCount), int32(concurrency)) test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientRecvCount), int32(concurrency)) test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(concurrency)) - atomic.StoreInt32(&serverRecvCount, 0) - atomic.StoreInt32(&serverSendCount, 0) + test.DeepEqual(t, atomic.LoadInt32(&clientSendCount), int32(concurrency)) + resetServerCount() + resetClientCount() // client stream t.Logf("=== ClientStream ===") @@ -271,7 +344,7 @@ func TestStreamingBasic(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - ctx, cs, err := streamClient.ClientStream(octx) + ctx, cs, err := cli.ClientStream(octx) test.Assert(t, err == nil, err) for i := 0; i < round; i++ { req := new(Request) @@ -283,16 +356,21 @@ func TestStreamingBasic(t *testing.T) { res, err := cs.CloseAndRecv(ctx) test.Assert(t, err == nil, err) test.Assert(t, res.Message == "ClientStream", res.Message) - atomic.AddInt32(&serverStreamCount, -1) testHeaderAndTrailer(t, cs) }() } wg.Wait() - waitServerStreamDone() - test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(round)*int32(concurrency)) + untilEqual(t, &serverStreamCount, int32(concurrency), time.Second) + test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientMiddlewareCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&serverStreamCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientStreamCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(round*concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientRecvCount), int32(concurrency)) test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(concurrency)) - atomic.StoreInt32(&serverRecvCount, 0) - atomic.StoreInt32(&serverSendCount, 0) + test.DeepEqual(t, atomic.LoadInt32(&clientSendCount), int32(round*concurrency)) + resetServerCount() + resetClientCount() // server stream t.Logf("=== ServerStream ===") @@ -302,7 +380,7 @@ func TestStreamingBasic(t *testing.T) { defer wg.Done() req := new(Request) req.Message = "ServerStream" - ctx, ss, err := streamClient.ServerStream(octx, req) + ctx, ss, err := cli.ServerStream(octx, req) test.Assert(t, err == nil, err) received := 0 for { @@ -315,14 +393,20 @@ func TestStreamingBasic(t *testing.T) { t.Logf("Client ServerStream recv: %v", res) } testHeaderAndTrailer(t, ss) - atomic.AddInt32(&serverStreamCount, -1) }() } wg.Wait() - waitServerStreamDone() + untilEqual(t, &serverStreamCount, int32(concurrency), time.Second) + test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientMiddlewareCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&serverStreamCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientStreamCount), int32(concurrency)) test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(concurrency)) - atomic.StoreInt32(&serverRecvCount, 0) - atomic.StoreInt32(&serverSendCount, 0) + test.DeepEqual(t, atomic.LoadInt32(&clientRecvCount), int32(round*concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(round*concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientSendCount), int32(concurrency)) + resetServerCount() + resetClientCount() // bidi stream t.Logf("=== BidiStream ===") @@ -330,7 +414,7 @@ func TestStreamingBasic(t *testing.T) { wg.Add(3) go func() { defer wg.Done() - ctx, bs, err := streamClient.BidiStream(octx) + ctx, bs, err := cli.BidiStream(octx) test.Assert(t, err == nil, err) msg := "BidiStream" go func() { @@ -359,20 +443,26 @@ func TestStreamingBasic(t *testing.T) { test.Assert(t, i == round, i) }() testHeaderAndTrailer(t, bs) - atomic.AddInt32(&serverStreamCount, -1) }() } wg.Wait() - waitServerStreamDone() + untilEqual(t, &serverStreamCount, int32(concurrency), time.Second) + untilEqual(t, &clientStreamCount, int32(concurrency), time.Second) + test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientMiddlewareCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&serverStreamCount), int32(concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientStreamCount), int32(concurrency)) test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(round*concurrency)) + test.DeepEqual(t, atomic.LoadInt32(&clientRecvCount), int32(round*concurrency)) test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(round*concurrency)) - atomic.StoreInt32(&serverRecvCount, 0) - atomic.StoreInt32(&serverSendCount, 0) + test.DeepEqual(t, atomic.LoadInt32(&clientSendCount), int32(round*concurrency)) + resetServerCount() + resetClientCount() t.Logf("=== UnaryWithErr normalErr ===") req := new(Request) req.Type = normalErr - res, err := streamClient.UnaryWithErr(octx, req) + res, err := cli.UnaryWithErr(octx, req) test.Assert(t, res == nil, res) test.Assert(t, err != nil, err) assertNormalErr(t, err) @@ -380,13 +470,13 @@ func TestStreamingBasic(t *testing.T) { t.Logf("=== UnaryWithErr bizErr ===") req = new(Request) req.Type = bizErr - res, err = streamClient.UnaryWithErr(octx, req) + res, err = cli.UnaryWithErr(octx, req) test.Assert(t, res == nil, res) test.Assert(t, err != nil, err) assertBizErr(t, err) t.Logf("=== ClientStreamWithErr normalErr ===") - ctx, cliStream, err := streamClient.ClientStreamWithErr(octx) + ctx, cliStream, err := cli.ClientStreamWithErr(octx) test.Assert(t, err == nil, err) test.Assert(t, cliStream != nil, cliStream) req = new(Request) @@ -399,7 +489,7 @@ func TestStreamingBasic(t *testing.T) { assertNormalErr(t, err) t.Logf("=== ClientStreamWithErr bizErr ===") - ctx, cliStream, err = streamClient.ClientStreamWithErr(octx) + ctx, cliStream, err = cli.ClientStreamWithErr(octx) test.Assert(t, err == nil, err) test.Assert(t, cliStream != nil, cliStream) req = new(Request) @@ -414,7 +504,7 @@ func TestStreamingBasic(t *testing.T) { t.Logf("=== ServerStreamWithErr normalErr ===") req = new(Request) req.Type = normalErr - ctx, svrStream, err := streamClient.ServerStreamWithErr(octx, req) + ctx, svrStream, err := cli.ServerStreamWithErr(octx, req) test.Assert(t, err == nil, err) test.Assert(t, svrStream != nil, svrStream) res, err = svrStream.Recv(ctx) @@ -425,7 +515,7 @@ func TestStreamingBasic(t *testing.T) { t.Logf("=== ServerStreamWithErr bizErr ===") req = new(Request) req.Type = bizErr - ctx, svrStream, err = streamClient.ServerStreamWithErr(octx, req) + ctx, svrStream, err = cli.ServerStreamWithErr(octx, req) test.Assert(t, err == nil, err) test.Assert(t, svrStream != nil, svrStream) res, err = svrStream.Recv(ctx) @@ -434,7 +524,7 @@ func TestStreamingBasic(t *testing.T) { assertBizErr(t, err) t.Logf("=== BidiStreamWithErr normalErr ===") - ctx, bidiStream, err := streamClient.BidiStreamWithErr(octx) + ctx, bidiStream, err := cli.BidiStreamWithErr(octx) test.Assert(t, err == nil, err) test.Assert(t, bidiStream != nil, bidiStream) req = new(Request) @@ -447,7 +537,7 @@ func TestStreamingBasic(t *testing.T) { assertNormalErr(t, err) t.Logf("=== BidiStreamWithErr bizErr ===") - ctx, bidiStream, err = streamClient.BidiStreamWithErr(octx) + ctx, bidiStream, err = cli.BidiStreamWithErr(octx) test.Assert(t, err == nil, err) test.Assert(t, bidiStream != nil, bidiStream) req = new(Request) @@ -460,7 +550,7 @@ func TestStreamingBasic(t *testing.T) { assertBizErr(t, err) t.Logf("=== Timeout by Ctx ===") - ctx, bs, err := streamClient.BidiStream(octx) + ctx, bs, err := cli.BidiStream(octx) test.Assert(t, err == nil, err) req = new(Request) req.Message = string(make([]byte, 1024)) @@ -476,13 +566,13 @@ func TestStreamingBasic(t *testing.T) { // timeout by client WithRecvTimeout t.Logf("=== Timeout by WithRecvTimeout ===") - streamClient, _ = NewStreamingClient( - "kitex.service.streaming", - streamxclient.WithHostPorts(addr), + cli, _ = NewClient( + "kitex.service.test", + client.WithHostPorts(addr), streamxclient.WithProvider(tc.ClientProvider), - streamxclient.WithRecvTimeout(time.Nanosecond), + streamxclient.WithStreamRecvTimeout(time.Nanosecond), ) - ctx, bs, err = streamClient.BidiStream(octx) + ctx, bs, err = cli.BidiStream(octx) test.Assert(t, err == nil, err) req = new(Request) req.Message = string(make([]byte, 1024)) @@ -494,12 +584,12 @@ func TestStreamingBasic(t *testing.T) { err = bs.CloseSend(ctx) test.Assert(t, err == nil, err) - streamClient = nil + cli = nil }) } } -func TestStreamingGoroutineLeak(t *testing.T) { +func TestStreamingException(t *testing.T) { for _, tc := range providerTestCases { t.Run(tc.Name, func(t *testing.T) { addr := test.GetLocalAddress() @@ -507,7 +597,55 @@ func TestStreamingGoroutineLeak(t *testing.T) { defer ln.Close() // create server - svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) + addr, svr, err := NewTestServer( + new(testService), + streamxserver.WithProvider(tc.ServerProvider), + ) + test.Assert(t, err == nil, err) + defer svr.Stop() + + var circuitBreaker int32 + circuitBreakerErr := fmt.Errorf("circuitBreaker on") + cli, _ := NewClient( + "kitex.echo.service", + client.WithHostPorts(addr), + + streamxclient.WithProvider(tc.ClientProvider), + streamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + ri := rpcinfo.GetRPCInfo(ctx) + test.Assert(t, ri.To().Address() != nil) + if atomic.LoadInt32(&circuitBreaker) > 0 { + return circuitBreakerErr + } + return next(ctx, streamArgs, reqArgs, resArgs) + } + }), + ) + octx := context.Background() + + // assert circuitBreaker error + atomic.StoreInt32(&circuitBreaker, 1) + _, _, err = cli.BidiStream(octx) + test.Assert(t, errors.Is(err, circuitBreakerErr), err) + atomic.StoreInt32(&circuitBreaker, 0) + + // assert context deadline error + ctx, cancel := context.WithTimeout(octx, time.Millisecond) + ctx, bs, err := cli.BidiStream(ctx) + test.Assert(t, err == nil, err) + res, err := bs.Recv(ctx) + cancel() + test.Assert(t, res == nil && err != nil, res, err) + test.Assert(t, errors.Is(err, ctx.Err()), err) + test.Assert(t, errors.Is(err, context.DeadlineExceeded), err) + }) + } +} + +func TestStreamingGoroutineLeak(t *testing.T) { + for _, tc := range providerTestCases { + t.Run(tc.Name, func(t *testing.T) { var streamStarted int32 waitStreamStarted := func(streamWaited int) { for { @@ -519,8 +657,10 @@ func TestStreamingGoroutineLeak(t *testing.T) { time.Sleep(time.Millisecond * 10) } } - _ = svr.RegisterService( - streamingServiceInfo, new(streamingService), + + // create server + addr, svr, err := NewTestServer( + new(testService), streamxserver.WithProvider(tc.ServerProvider), streamxserver.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { @@ -529,15 +669,12 @@ func TestStreamingGoroutineLeak(t *testing.T) { } }), ) - go func() { - _ = svr.Run() - }() + test.Assert(t, err == nil, err) defer svr.Stop() - test.WaitServerStart(addr) - streamClient, _ := NewStreamingClient( - "kitex.service.streaming", - streamxclient.WithHostPorts(addr), + cli, _ := NewClient( + "kitex.test.service", + client.WithHostPorts(addr), streamxclient.WithProvider(tc.ClientProvider), ) octx := context.Background() @@ -547,7 +684,7 @@ func TestStreamingGoroutineLeak(t *testing.T) { var wg sync.WaitGroup for i := 0; i < 12; i++ { wg.Add(1) - ctx, bs, err := streamClient.BidiStream(octx) + ctx, bs, err := cli.BidiStream(octx) test.Assert(t, err == nil, err) req := new(Request) req.Message = string(make([]byte, 1024)) @@ -572,13 +709,13 @@ func TestStreamingGoroutineLeak(t *testing.T) { streamList := make([]streamx.ClientStream, streams) atomic.StoreInt32(&streamStarted, 0) for i := 0; i < streams; i++ { - _, bs, err := streamClient.BidiStream(octx) + _, bs, err := cli.BidiStream(octx) test.Assert(t, err == nil, err) streamList[i] = bs } waitStreamStarted(streams) // before GC - test.Assert(t, runtime.NumGoroutine() > streams, runtime.NumGoroutine()) + test.Assert(t, runtime.NumGoroutine() >= streams, runtime.NumGoroutine(), streams) // after GC for i := 0; i < streams; i++ { streamList[i] = nil @@ -586,7 +723,7 @@ func TestStreamingGoroutineLeak(t *testing.T) { for runtime.NumGoroutine() > ngBefore { t.Logf("ngCurrent=%d > ngBefore=%d", runtime.NumGoroutine(), ngBefore) runtime.GC() - time.Sleep(time.Millisecond * 50) + time.Sleep(time.Millisecond * 10) } t.Logf("=== Checking Streams Called and GCed ===") @@ -597,7 +734,7 @@ func TestStreamingGoroutineLeak(t *testing.T) { go func() { defer wg.Done() - ctx, bs, err := streamClient.BidiStream(octx) + ctx, bs, err := cli.BidiStream(octx) test.Assert(t, err == nil, err) req := new(Request) req.Message = msg @@ -609,8 +746,6 @@ func TestStreamingGoroutineLeak(t *testing.T) { err = bs.CloseSend(ctx) test.Assert(t, err == nil, err) test.Assert(t, res.Message == msg, res.Message) - - testHeaderAndTrailer(t, bs) }() } wg.Wait() @@ -630,7 +765,7 @@ func TestStreamingGoroutineLeak(t *testing.T) { req := new(Request) req.Message = msg - ctx, ss, err := streamClient.ServerStream(octx, req) + ctx, ss, err := cli.ServerStream(octx, req) test.Assert(t, err == nil, err) for { @@ -641,7 +776,6 @@ func TestStreamingGoroutineLeak(t *testing.T) { test.Assert(t, err == nil, err) test.Assert(t, res.Message == msg, res.Message) } - testHeaderAndTrailer(t, ss) }() } wg.Wait() @@ -653,60 +787,3 @@ func TestStreamingGoroutineLeak(t *testing.T) { }) } } - -func TestStreamingException(t *testing.T) { - for _, tc := range providerTestCases { - t.Run(tc.Name, func(t *testing.T) { - addr := test.GetLocalAddress() - ln, _ := netpoll.CreateListener("tcp", addr) - defer ln.Close() - - // create server - svr := server.NewServer(server.WithListener(ln), server.WithExitWaitTime(time.Millisecond*10)) - _ = svr.RegisterService( - streamingServiceInfo, new(streamingService), - streamxserver.WithProvider(tc.ServerProvider), - ) - go func() { - _ = svr.Run() - }() - defer svr.Stop() - test.WaitServerStart(addr) - - var circuitBreaker int32 - circuitBreakerErr := fmt.Errorf("circuitBreaker on") - streamClient, _ := NewStreamingClient( - "kitex.service.streaming", - streamxclient.WithHostPorts(addr), - streamxclient.WithProvider(tc.ClientProvider), - streamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { - return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { - ri := rpcinfo.GetRPCInfo(ctx) - test.Assert(t, ri.To().Address() != nil) - if atomic.LoadInt32(&circuitBreaker) > 0 { - return circuitBreakerErr - } - return next(ctx, streamArgs, reqArgs, resArgs) - } - }), - ) - octx := context.Background() - - // assert circuitBreaker error - atomic.StoreInt32(&circuitBreaker, 1) - _, _, err := streamClient.BidiStream(octx) - test.Assert(t, errors.Is(err, circuitBreakerErr), err) - atomic.StoreInt32(&circuitBreaker, 0) - - // assert context deadline error - ctx, cancel := context.WithTimeout(octx, time.Millisecond) - ctx, bs, err := streamClient.BidiStream(ctx) - test.Assert(t, err == nil, err) - res, err := bs.Recv(ctx) - cancel() - test.Assert(t, res == nil && err != nil, res, err) - test.Assert(t, errors.Is(err, ctx.Err()), err) - test.Assert(t, errors.Is(err, context.DeadlineExceeded), err) - }) - } -} diff --git a/server/server.go b/server/server.go index 0e939b2d68..d056299c07 100644 --- a/server/server.go +++ b/server/server.go @@ -27,12 +27,11 @@ import ( "sync" "time" + "github.com/cloudwego/localsession/backup" + "github.com/cloudwego/kitex/pkg/remote/trans/detection" "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" - "github.com/cloudwego/kitex/pkg/streamx" - - "github.com/cloudwego/localsession/backup" internal_server "github.com/cloudwego/kitex/internal/server" "github.com/cloudwego/kitex/pkg/acl" @@ -216,17 +215,6 @@ func (s *server) RegisterService(svcInfo *serviceinfo.ServiceInfo, handler inter } registerOpts := internal_server.NewRegisterOptions(opts) - // add trace middlewares - ehandler := s.opt.TracerCtl.GetStreamEventHandler() - if ehandler != nil { - registerOpts.StreamRecvMiddlewares = append( - registerOpts.StreamRecvMiddlewares, streamx.NewStreamRecvStatMiddleware(ehandler), - ) - registerOpts.StreamSendMiddlewares = append( - registerOpts.StreamSendMiddlewares, streamx.NewStreamSendStatMiddleware(ehandler), - ) - } - // register service if err := s.svcs.addService(svcInfo, handler, registerOpts); err != nil { panic(err.Error()) @@ -419,16 +407,7 @@ func (s *server) invokeHandleEndpoint() endpoint.Endpoint { // set session backup.BackupCtx(ctx) - handler := svc.handler - if minfo.IsStreaming() && svcInfo.Extra["streamx"] != nil { - handler = streamx.StreamHandler{ - Handler: svc.handler, - StreamMiddleware: svc.SMW, - StreamRecvMiddleware: svc.SRecvMW, - StreamSendMiddleware: svc.SSendMW, - } - } - err = implHandlerFunc(ctx, handler, args, resp) + err = implHandlerFunc(ctx, svc.handler, args, resp) if err != nil { if bizErr, ok := kerrors.FromBizStatusError(err); ok { if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { @@ -569,12 +548,10 @@ func (s *server) newSvrTransHandler() (handler remote.ServerTransHandler, err er transHdlrFactory := s.opt.RemoteOpt.SvrHandlerFactory if transHdlrFactory == nil { candidateFactories := make([]remote.ServerTransHandlerFactory, 0) - for _, svc := range s.svcs.svcMap { - if svc.streamingProvider != nil { - candidateFactories = append(candidateFactories, - streamxstrans.NewSvrTransHandlerFactory(svc.streamingProvider), - ) - } + if s.opt.StreamX.Provider != nil { + candidateFactories = append(candidateFactories, + streamxstrans.NewSvrTransHandlerFactory(s.opt.StreamX.Provider), + ) } candidateFactories = append(candidateFactories, nphttp2.NewSvrTransHandlerFactory(), diff --git a/server/service.go b/server/service.go index 6dbc302f62..9af7c3315d 100644 --- a/server/service.go +++ b/server/service.go @@ -22,25 +22,20 @@ import ( "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/serviceinfo" - "github.com/cloudwego/kitex/pkg/streamx" ) type serviceMiddlewares struct { - MW endpoint.Middleware - SMW streamx.StreamMiddleware - SRecvMW streamx.StreamRecvMiddleware - SSendMW streamx.StreamSendMiddleware + MW endpoint.Middleware } type service struct { - svcInfo *serviceinfo.ServiceInfo - handler interface{} - streamingProvider streamx.ServerProvider + svcInfo *serviceinfo.ServiceInfo + handler interface{} serviceMiddlewares } -func newService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, provider streamx.ServerProvider, smw serviceMiddlewares) *service { - return &service{svcInfo: svcInfo, handler: handler, streamingProvider: provider, serviceMiddlewares: smw} +func newService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, smw serviceMiddlewares) *service { + return &service{svcInfo: svcInfo, handler: handler, serviceMiddlewares: smw} } type services struct { @@ -59,25 +54,13 @@ func newServices() *services { } func (s *services) addService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, registerOpts *RegisterOptions) error { - // prepare service provider - serviceProvider := registerOpts.Provider - // prepare serviceMiddlewares var serviceMWs serviceMiddlewares if len(registerOpts.Middlewares) > 0 { serviceMWs.MW = endpoint.Chain(registerOpts.Middlewares...) } - if len(registerOpts.StreamMiddlewares) > 0 { - serviceMWs.SMW = streamx.StreamMiddlewareChain(registerOpts.StreamMiddlewares...) - } - if len(registerOpts.StreamRecvMiddlewares) > 0 { - serviceMWs.SRecvMW = streamx.StreamRecvMiddlewareChain(registerOpts.StreamRecvMiddlewares...) - } - if len(registerOpts.StreamSendMiddlewares) > 0 { - serviceMWs.SSendMW = streamx.StreamSendMiddlewareChain(registerOpts.StreamSendMiddlewares...) - } - svc := newService(svcInfo, handler, serviceProvider, serviceMWs) + svc := newService(svcInfo, handler, serviceMWs) if registerOpts.IsFallbackService { if s.fallbackSvc != nil { return fmt.Errorf("multiple fallback services cannot be registered. [%s] is already registered as a fallback service", s.fallbackSvc.svcInfo.ServiceName) diff --git a/server/stream.go b/server/stream.go index 1043c69ff6..59b844bb1a 100644 --- a/server/stream.go +++ b/server/stream.go @@ -21,17 +21,36 @@ import ( "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/pkg/streamx" ) func (s *server) initStreamMiddlewares(ctx context.Context) { - // for old version streaming + // === for old version streaming === s.opt.Streaming.EventHandler = s.opt.TracerCtl.GetStreamEventHandler() s.opt.Streaming.InitMiddlewares(ctx) + + // === for streamx version streaming === + // add tracing middlewares + ehandler := s.opt.TracerCtl.GetStreamEventHandler() + if ehandler != nil { + s.opt.StreamX.StreamRecvMiddlewares = append( + s.opt.StreamX.StreamRecvMiddlewares, streamx.NewStreamRecvStatMiddleware(ehandler), + ) + s.opt.StreamX.StreamSendMiddlewares = append( + s.opt.StreamX.StreamSendMiddlewares, streamx.NewStreamSendStatMiddleware(ehandler), + ) + } } func (s *server) buildStreamInvokeChain() { + // === for old version streaming === s.opt.RemoteOpt.RecvEndpoint = s.opt.Streaming.BuildRecvInvokeChain(s.invokeRecvEndpoint()) s.opt.RemoteOpt.SendEndpoint = s.opt.Streaming.BuildSendInvokeChain(s.invokeSendEndpoint()) + + // === for streamx version streaming === + s.opt.RemoteOpt.StreamMiddleware = streamx.StreamMiddlewareChain(s.opt.StreamX.StreamMiddlewares...) + s.opt.RemoteOpt.StreamRecvMiddleware = streamx.StreamRecvMiddlewareChain(s.opt.StreamX.StreamRecvMiddlewares...) + s.opt.RemoteOpt.StreamSendMiddleware = streamx.StreamSendMiddlewareChain(s.opt.StreamX.StreamSendMiddlewares...) } func (s *server) invokeRecvEndpoint() endpoint.RecvEndpoint { diff --git a/server/streamxserver/option.go b/server/streamxserver/option.go new file mode 100644 index 0000000000..109f0c096a --- /dev/null +++ b/server/streamxserver/option.go @@ -0,0 +1,48 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package streamxserver + +import ( + internal_server "github.com/cloudwego/kitex/internal/server" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/utils" + "github.com/cloudwego/kitex/server" +) + +func WithProvider(provider streamx.ServerProvider) server.Option { + return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { + o.StreamX.Provider = provider + }} +} + +func WithStreamMiddleware(mw streamx.StreamMiddleware) server.Option { + return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { + o.StreamX.StreamMiddlewares = append(o.StreamX.StreamMiddlewares, mw) + }} +} + +func WithStreamRecvMiddleware(mw streamx.StreamRecvMiddleware) server.Option { + return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { + o.StreamX.StreamRecvMiddlewares = append(o.StreamX.StreamRecvMiddlewares, mw) + }} +} + +func WithStreamSendMiddleware(mw streamx.StreamSendMiddleware) server.Option { + return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { + o.StreamX.StreamSendMiddlewares = append(o.StreamX.StreamSendMiddlewares, mw) + }} +} diff --git a/server/streamxserver/server_gen.go b/server/streamxserver/server_gen.go index 74d47067af..bb86793b35 100644 --- a/server/streamxserver/server_gen.go +++ b/server/streamxserver/server_gen.go @@ -1,127 +1,150 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. package streamxserver import ( "context" "errors" - "reflect" - "sync" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" ) -var invokerCache sync.Map +var errServerStreamArgsNotFound = errors.New("stream args not found") -func InvokeStream[Req, Res any]( - ctx context.Context, smode serviceinfo.StreamingMode, - handler any, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs, -) (err error) { - // prepare args +func prepareInvokeStream[Req, Res any](sArgs streamx.StreamArgs) (*streamx.GenericServerStream[Req, Res], streamx.StreamMiddleware) { + gs := streamx.NewGenericServerStream[Req, Res](sArgs.Stream().(streamx.ServerStream)) + swArgs, ok := sArgs.(streamx.StreamMiddlewaresArgs) + if !ok { + return gs, nil + } + sMW, recvMW, sendMW := swArgs.Middlewares() + gs.SetStreamRecvMiddleware(recvMW) + gs.SetStreamSendMiddleware(sendMW) + return gs, sMW +} + +func InvokeUnaryHandler[Req, Res any]( + ctx context.Context, + reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs, + methodHandler streamx.UnaryHandler[Req, Res], +) error { sArgs := streamx.GetStreamArgsFromContext(ctx) if sArgs == nil { - return errors.New("server stream is nil") + return errServerStreamArgsNotFound } - shandler := handler.(streamx.StreamHandler) - gs := streamx.NewGenericServerStream[Req, Res](sArgs.Stream().(streamx.ServerStream)) - gs.SetStreamRecvMiddleware(shandler.StreamRecvMiddleware) - gs.SetStreamSendMiddleware(shandler.StreamSendMiddleware) + gs, sMW := prepareInvokeStream[Req, Res](sArgs) - // before handler - var req *Req - var res *Res - switch smode { - case serviceinfo.StreamingUnary, serviceinfo.StreamingServer: - req, err = gs.Recv(ctx) + // before handler call + req, err := gs.Recv(ctx) + if err != nil { + return err + } + reqArgs.SetReq(req) + + // handler call + invoke := func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + res, err := methodHandler(ctx, req) if err != nil { return err } - reqArgs.SetReq(req) - default: + resArgs.SetRes(res) + return gs.Send(ctx, res) + } + if sMW != nil { + err = sMW(invoke)(ctx, sArgs, reqArgs, resArgs) + } else { + err = invoke(ctx, sArgs, reqArgs, resArgs) } + return err +} + +func InvokeClientStreamHandler[Req, Res any]( + ctx context.Context, + reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs, + methodHandler streamx.ClientStreamingHandler[Req, Res], +) (err error) { + sArgs := streamx.GetStreamArgsFromContext(ctx) + if sArgs == nil { + return errServerStreamArgsNotFound + } + gs, sMW := prepareInvokeStream[Req, Res](sArgs) + invoke := func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + res, err := methodHandler(ctx, gs) + if err != nil { + return err + } + resArgs.SetRes(res) + return gs.Send(ctx, res) + } + if sMW != nil { + err = sMW(invoke)(ctx, sArgs, reqArgs, resArgs) + } else { + err = invoke(ctx, sArgs, reqArgs, resArgs) + } + return err +} + +func InvokeServerStreamHandler[Req, Res any]( + ctx context.Context, + reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs, + methodHandler streamx.ServerStreamingHandler[Req, Res], +) (err error) { + sArgs := streamx.GetStreamArgsFromContext(ctx) + if sArgs == nil { + return errServerStreamArgsNotFound + } + gs, sMW := prepareInvokeStream[Req, Res](sArgs) + + // before handler call + req, err := gs.Recv(ctx) + if err != nil { + return err + } + reqArgs.SetReq(req) // handler call - cacheKey := reflect.TypeOf(shandler.Handler).String() + sArgs.Stream().Method() - var mhandler reflect.Value - if v, ok := invokerCache.Load(cacheKey); ok { - mhandler = v.(reflect.Value) + invoke := func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + return methodHandler(ctx, req, gs) + } + if sMW != nil { + err = sMW(invoke)(ctx, sArgs, reqArgs, resArgs) } else { - rhandler := reflect.ValueOf(shandler.Handler) - mhandler = rhandler.MethodByName(sArgs.Stream().Method()) - invokerCache.Store(cacheKey, mhandler) + err = invoke(ctx, sArgs, reqArgs, resArgs) } + return err +} - streamInvoke := func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { - switch smode { - case serviceinfo.StreamingUnary: - called := mhandler.Call([]reflect.Value{ - reflect.ValueOf(ctx), - reflect.ValueOf(req), - }) - _res, _err := called[0].Interface(), called[1].Interface() - if _err != nil { - return _err.(error) - } - res = _res.(*Res) - if err = gs.SendAndClose(ctx, res); err != nil { - return err - } - resArgs.SetRes(res) - case serviceinfo.StreamingClient: - called := mhandler.Call([]reflect.Value{ - reflect.ValueOf(ctx), - reflect.ValueOf(gs), - }) - _res, _err := called[0].Interface(), called[1].Interface() - if _err != nil { - return _err.(error) - } - res = _res.(*Res) - if err = gs.Send(ctx, res); err != nil { - return err - } - resArgs.SetRes(res) - case serviceinfo.StreamingServer: - called := mhandler.Call([]reflect.Value{ - reflect.ValueOf(ctx), - reflect.ValueOf(req), - reflect.ValueOf(gs), - }) - _err := called[0].Interface() - if _err != nil { - return _err.(error) - } - case serviceinfo.StreamingBidirectional: - called := mhandler.Call([]reflect.Value{ - reflect.ValueOf(ctx), - reflect.ValueOf(gs), - }) - _err := called[0].Interface() - if _err != nil { - return _err.(error) - } - } - return nil +func InvokeBidiStreamHandler[Req, Res any]( + ctx context.Context, + reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs, + methodHandler streamx.BidiStreamingHandler[Req, Res], +) (err error) { + sArgs := streamx.GetStreamArgsFromContext(ctx) + if sArgs == nil { + return errServerStreamArgsNotFound + } + gs, sMW := prepareInvokeStream[Req, Res](sArgs) + + // handler call + invoke := func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + return methodHandler(ctx, gs) } - if shandler.StreamMiddleware != nil { - err = shandler.StreamMiddleware(streamInvoke)(ctx, sArgs, reqArgs, resArgs) + if sMW != nil { + err = sMW(invoke)(ctx, sArgs, reqArgs, resArgs) } else { - err = streamInvoke(ctx, sArgs, reqArgs, resArgs) + err = invoke(ctx, sArgs, reqArgs, resArgs) } return err } diff --git a/server/streamxserver/server_option.go b/server/streamxserver/server_option.go deleted file mode 100644 index 6869dcc8b5..0000000000 --- a/server/streamxserver/server_option.go +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package streamxserver - -import ( - "net" - - internal_server "github.com/cloudwego/kitex/internal/server" - "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/kitex/server" -) - -type ( - Option internal_server.Option - Options = internal_server.Options -) - -func WithListener(ln net.Listener) Option { - return ConvertNativeServerOption(server.WithListener(ln)) -} - -func WithStreamMiddleware(mw streamx.StreamMiddleware) server.RegisterOption { - return server.RegisterOption{F: func(o *internal_server.RegisterOptions) { - o.StreamMiddlewares = append(o.StreamMiddlewares, mw) - }} -} - -func WithStreamRecvMiddleware(mw streamx.StreamRecvMiddleware) server.RegisterOption { - return server.RegisterOption{F: func(o *internal_server.RegisterOptions) { - o.StreamRecvMiddlewares = append(o.StreamRecvMiddlewares, mw) - }} -} - -func WithStreamSendMiddleware(mw streamx.StreamSendMiddleware) server.RegisterOption { - return server.RegisterOption{F: func(o *internal_server.RegisterOptions) { - o.StreamSendMiddlewares = append(o.StreamSendMiddlewares, mw) - }} -} - -func WithProvider(provider streamx.ServerProvider) server.RegisterOption { - return server.RegisterOption{F: func(o *internal_server.RegisterOptions) { - o.Provider = provider - }} -} - -func ConvertNativeServerOption(o internal_server.Option) Option { - return Option{F: o.F} -} - -func ConvertStreamXServerOption(o Option) internal_server.Option { - return internal_server.Option{F: o.F} -} From f06ff05d10f64441d4836448d3521fef2f24d967 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Fri, 8 Nov 2024 16:31:51 +0800 Subject: [PATCH 17/34] chore: add comments --- client/client.go | 9 ++---- client/service_inline.go | 2 +- client/stream.go | 12 ++++---- client/streamclient/client_option.go | 2 +- client/streamxclient/client_gen.go | 1 + client/streamxclient/client_option.go | 30 +++++++++---------- .../streamxcallopt/call_option.go | 8 +++++ internal/client/option.go | 13 ++++++-- internal/server/option.go | 2 +- internal/server/streamx_config.go | 2 +- pkg/remote/option.go | 2 +- pkg/remote/remotecli/stream.go | 7 ++--- pkg/remote/remotesvr/server.go | 3 +- pkg/remote/remotesvr/server_test.go | 11 ++----- pkg/remote/trans/streamx/server_handler.go | 4 +-- pkg/streamx/client_options.go | 9 ------ server/server.go | 2 +- 17 files changed, 55 insertions(+), 64 deletions(-) diff --git a/client/client.go b/client/client.go index adef12ac6b..c2fdf25eb4 100644 --- a/client/client.go +++ b/client/client.go @@ -99,17 +99,12 @@ func (kf *kcFinalizerClient) Call(ctx context.Context, method string, request, r // NewClient creates a kitex.Client with the given ServiceInfo, it is from generated code. func NewClient(svcInfo *serviceinfo.ServiceInfo, opts ...Option) (Client, error) { - nopts := client.NewOptions(opts) - return NewClientWithOptions(svcInfo, nopts) -} - -func NewClientWithOptions(svcInfo *serviceinfo.ServiceInfo, opts *Options) (Client, error) { if svcInfo == nil { return nil, errors.New("NewClient: no service info") } kc := &kcFinalizerClient{kClient: &kClient{}} kc.svcInfo = svcInfo - kc.opt = opts + kc.opt = client.NewOptions(opts) if err := kc.init(); err != nil { _ = kc.Close() return nil, err @@ -755,7 +750,7 @@ func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInf } // streamx config - sopt := opt.StreamXOptions + sopt := opt.StreamX if sopt.RecvTimeout > 0 { cfg.SetStreamRecvTimeout(sopt.RecvTimeout) } diff --git a/client/service_inline.go b/client/service_inline.go index da7c0c3003..c807d749e3 100644 --- a/client/service_inline.go +++ b/client/service_inline.go @@ -130,7 +130,7 @@ func (kc *serviceInlineClient) Call(ctx context.Context, method string, request, } kc.opt.TracerCtl.DoFinish(ctx, ri, reportErr) // If the user start a new goroutine and return before endpoint finished, it may cause panic. - // For example,, if the user writes a timeout StreamMiddleware and times out, rpcinfo will be recycled, + // For example,, if the user writes a timeout Middleware and times out, rpcinfo will be recycled, // but in fact, rpcinfo is still being used when it is executed inside // So if endpoint returns err, client won't recycle rpcinfo. if reportErr == nil { diff --git a/client/stream.go b/client/stream.go index 073dbc4f73..5e2219bcc0 100644 --- a/client/stream.go +++ b/client/stream.go @@ -81,14 +81,14 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) { sendEndpoint := kc.opt.Streaming.BuildSendInvokeChain(kc.invokeSendEndpoint()) // streamx version streaming mw - kc.sxStreamMW = streamx.StreamMiddlewareChain(kc.opt.StreamXOptions.StreamMWs...) + kc.sxStreamMW = streamx.StreamMiddlewareChain(kc.opt.StreamX.StreamMWs...) eventHandler := kc.opt.TracerCtl.GetStreamEventHandler() if eventHandler != nil { - kc.opt.StreamXOptions.StreamRecvMWs = append(kc.opt.StreamXOptions.StreamRecvMWs, streamx.NewStreamRecvStatMiddleware(eventHandler)) - kc.opt.StreamXOptions.StreamSendMWs = append(kc.opt.StreamXOptions.StreamSendMWs, streamx.NewStreamSendStatMiddleware(eventHandler)) + kc.opt.StreamX.StreamRecvMWs = append(kc.opt.StreamX.StreamRecvMWs, streamx.NewStreamRecvStatMiddleware(eventHandler)) + kc.opt.StreamX.StreamSendMWs = append(kc.opt.StreamX.StreamSendMWs, streamx.NewStreamSendStatMiddleware(eventHandler)) } - kc.sxStreamRecvMW = streamx.StreamRecvMiddlewareChain(kc.opt.StreamXOptions.StreamRecvMWs...) - kc.sxStreamSendMW = streamx.StreamSendMiddlewareChain(kc.opt.StreamXOptions.StreamSendMWs...) + kc.sxStreamRecvMW = streamx.StreamRecvMiddlewareChain(kc.opt.StreamX.StreamRecvMWs...) + kc.sxStreamSendMW = streamx.StreamSendMiddlewareChain(kc.opt.StreamX.StreamSendMWs...) return func(ctx context.Context, req, resp interface{}) (err error) { // req and resp as &streaming.Stream @@ -101,7 +101,7 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) { // streamx API if cs, ok := st.(streamx.Stream); ok { streamArgs := resp.(streamx.StreamArgs) - // 此后的中间件才会有 Stream + // the middlewares can get real Stream after set streamx.AsMutableStreamArgs(streamArgs).SetStream(cs) return nil } diff --git a/client/streamclient/client_option.go b/client/streamclient/client_option.go index 0cb6ba94a7..95f161ac49 100644 --- a/client/streamclient/client_option.go +++ b/client/streamclient/client_option.go @@ -42,7 +42,7 @@ func WithSuite(suite client.Suite) Option { // WithMiddleware adds middleware for client to handle request. // NOTE: for streaming APIs (bidirectional, client, server), req is not valid, resp is *streaming.Result -// If you want to intercept recv/send calls, please use Recv/Send StreamMiddleware +// If you want to intercept recv/send calls, please use Recv/Send Middleware func WithMiddleware(mw endpoint.Middleware) Option { return ConvertOptionFrom(client.WithMiddleware(mw)) } diff --git a/client/streamxclient/client_gen.go b/client/streamxclient/client_gen.go index a5ff65b64c..fedc42b12d 100644 --- a/client/streamxclient/client_gen.go +++ b/client/streamxclient/client_gen.go @@ -25,6 +25,7 @@ import ( "github.com/cloudwego/kitex/pkg/streamx" ) +// InvokeStream create a new client stream and wrapped related middlewares func InvokeStream[Req, Res any]( ctx context.Context, cli client.StreamX, smode serviceinfo.StreamingMode, method string, req *Req, res *Res, callOptions ...streamxcallopt.CallOption, diff --git a/client/streamxclient/client_option.go b/client/streamxclient/client_option.go index 134f0a7483..7005622c43 100644 --- a/client/streamxclient/client_option.go +++ b/client/streamxclient/client_option.go @@ -24,34 +24,32 @@ import ( "github.com/cloudwego/kitex/pkg/utils" ) -type Option = internal_client.Option - -func WithProvider(pvd streamx.ClientProvider) Option { - return Option{F: func(o *internal_client.Options, di *utils.Slice) { +func WithProvider(pvd streamx.ClientProvider) internal_client.Option { + return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { o.RemoteOpt.Provider = pvd }} } -func WithStreamRecvTimeout(timeout time.Duration) Option { - return Option{F: func(o *internal_client.Options, di *utils.Slice) { - o.StreamXOptions.RecvTimeout = timeout +func WithStreamRecvTimeout(timeout time.Duration) internal_client.Option { + return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { + o.StreamX.RecvTimeout = timeout }} } -func WithStreamMiddleware(smw streamx.StreamMiddleware) Option { - return Option{F: func(o *internal_client.Options, di *utils.Slice) { - o.StreamXOptions.StreamMWs = append(o.StreamXOptions.StreamMWs, smw) +func WithStreamMiddleware(smw streamx.StreamMiddleware) internal_client.Option { + return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { + o.StreamX.StreamMWs = append(o.StreamX.StreamMWs, smw) }} } -func WithStreamRecvMiddleware(smw streamx.StreamRecvMiddleware) Option { - return Option{F: func(o *internal_client.Options, di *utils.Slice) { - o.StreamXOptions.StreamRecvMWs = append(o.StreamXOptions.StreamRecvMWs, smw) +func WithStreamRecvMiddleware(smw streamx.StreamRecvMiddleware) internal_client.Option { + return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { + o.StreamX.StreamRecvMWs = append(o.StreamX.StreamRecvMWs, smw) }} } -func WithStreamSendMiddleware(smw streamx.StreamSendMiddleware) Option { - return Option{F: func(o *internal_client.Options, di *utils.Slice) { - o.StreamXOptions.StreamSendMWs = append(o.StreamXOptions.StreamSendMWs, smw) +func WithStreamSendMiddleware(smw streamx.StreamSendMiddleware) internal_client.Option { + return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { + o.StreamX.StreamSendMWs = append(o.StreamX.StreamSendMWs, smw) }} } diff --git a/client/streamxclient/streamxcallopt/call_option.go b/client/streamxclient/streamxcallopt/call_option.go index bbd6a157cc..a40f6f2aed 100644 --- a/client/streamxclient/streamxcallopt/call_option.go +++ b/client/streamxclient/streamxcallopt/call_option.go @@ -20,25 +20,31 @@ import ( "context" ) +// StreamCloseCallback define close callback of the stream type StreamCloseCallback func() +// CallOptions define stream call level options type CallOptions struct { StreamCloseCallback []StreamCloseCallback } +// CallOption define stream call level option type CallOption struct { f func(o *CallOptions) } +// WithCallOption add call option type WithCallOption func(o *CallOption) type ctxKeyCallOptions struct{} +// NewCtxWithCallOptions register CallOptions into context func NewCtxWithCallOptions(ctx context.Context) (context.Context, *CallOptions) { copts := new(CallOptions) return context.WithValue(ctx, ctxKeyCallOptions{}, copts), copts } +// GetCallOptionsFromCtx get CallOptions from context func GetCallOptionsFromCtx(ctx context.Context) *CallOptions { v := ctx.Value(ctxKeyCallOptions{}) if v == nil { @@ -51,12 +57,14 @@ func GetCallOptionsFromCtx(ctx context.Context) *CallOptions { return copts } +// Apply call options func (copts *CallOptions) Apply(opts []CallOption) { for _, opt := range opts { opt.f(copts) } } +// WithStreamCloseCallback register StreamCloseCallback func WithStreamCloseCallback(callback StreamCloseCallback) CallOption { return CallOption{f: func(o *CallOptions) { o.StreamCloseCallback = append(o.StreamCloseCallback, callback) diff --git a/internal/client/option.go b/internal/client/option.go index 7452facd0b..86efeac653 100644 --- a/internal/client/option.go +++ b/internal/client/option.go @@ -22,7 +22,6 @@ import ( "time" "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/localsession/backup" "github.com/cloudwego/kitex/internal/configutil" @@ -121,8 +120,8 @@ type Options struct { // Context backup CtxBackupHandler backup.BackupHandler - Streaming stream.StreamingConfig - StreamXOptions streamx.ClientOptions + Streaming stream.StreamingConfig + StreamX StreamXOptions } // Apply applies all options. @@ -213,3 +212,11 @@ func (o *Options) InitRetryContainer() { o.CloseCallbacks = append(o.CloseCallbacks, o.RetryContainer.Close) } } + +// StreamXOptions define the client options +type StreamXOptions struct { + RecvTimeout time.Duration + StreamMWs []streamx.StreamMiddleware + StreamRecvMWs []streamx.StreamRecvMiddleware + StreamSendMWs []streamx.StreamSendMiddleware +} diff --git a/internal/server/option.go b/internal/server/option.go index 7bf3f21b7f..f7dc6f22aa 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -98,7 +98,7 @@ type Options struct { // Streaming Streaming stream.StreamingConfig // old version streaming API config - StreamX StreamXConfig // new version streaming API config + StreamX StreamXOptions // new version streaming API config RefuseTrafficWithoutServiceName bool EnableContextTimeout bool diff --git a/internal/server/streamx_config.go b/internal/server/streamx_config.go index 9890b9cdb4..43e07658db 100644 --- a/internal/server/streamx_config.go +++ b/internal/server/streamx_config.go @@ -18,7 +18,7 @@ package server import "github.com/cloudwego/kitex/pkg/streamx" -type StreamXConfig struct { +type StreamXOptions struct { StreamMiddlewares []streamx.StreamMiddleware StreamRecvMiddlewares []streamx.StreamRecvMiddleware StreamSendMiddlewares []streamx.StreamSendMiddleware diff --git a/pkg/remote/option.go b/pkg/remote/option.go index 573b17eb8a..7b78cc2569 100644 --- a/pkg/remote/option.go +++ b/pkg/remote/option.go @@ -148,5 +148,5 @@ type ClientOption struct { EnableConnPoolReporter bool - Provider interface{} // streamx.ClientProvider + Provider streamx.ClientProvider } diff --git a/pkg/remote/remotecli/stream.go b/pkg/remote/remotecli/stream.go index 3ecc05c0aa..1ea1df172d 100644 --- a/pkg/remote/remotecli/stream.go +++ b/pkg/remote/remotecli/stream.go @@ -37,10 +37,9 @@ func NewStream(ctx context.Context, ri rpcinfo.RPCInfo, handler remote.ClientTra } // streamx provider - clientProvider, ok := opt.Provider.(streamx.ClientProvider) - if ok { - // wrap client provider - clientProvider = streamx.NewClientProvider(clientProvider) + if opt.Provider != nil { + // wrap internal client provider + clientProvider := streamx.NewClientProvider(opt.Provider) cs, err := clientProvider.NewStream(ctx, ri) if err != nil { return nil, nil, err diff --git a/pkg/remote/remotesvr/server.go b/pkg/remote/remotesvr/server.go index 05c566976b..c3ae96aed5 100644 --- a/pkg/remote/remotesvr/server.go +++ b/pkg/remote/remotesvr/server.go @@ -21,7 +21,6 @@ import ( "net" "sync" - "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote" @@ -42,7 +41,7 @@ type server struct { } // NewServer creates a remote server. -func NewServer(opt *remote.ServerOption, inkHdlFunc endpoint.Endpoint, transHdlr remote.ServerTransHandler) (Server, error) { +func NewServer(opt *remote.ServerOption, transHdlr remote.ServerTransHandler) (Server, error) { transSvr := opt.TransServerFactory.NewTransServer(opt, transHdlr) s := &server{ opt: opt, diff --git a/pkg/remote/remotesvr/server_test.go b/pkg/remote/remotesvr/server_test.go index 078e50dfbd..779fb46909 100644 --- a/pkg/remote/remotesvr/server_test.go +++ b/pkg/remote/remotesvr/server_test.go @@ -17,7 +17,6 @@ package remotesvr import ( - "context" "errors" "net" "testing" @@ -53,11 +52,8 @@ func TestServerStart(t *testing.T) { Address: utils.NewNetAddr("tcp", "test"), TransServerFactory: mocks.NewMockTransServerFactory(transSvr), } - inkHdlrFunc := func(ctx context.Context, req, resp interface{}) (err error) { - return nil - } transHdrl := &mocks.MockSvrTransHandler{} - svr, err := NewServer(opt, inkHdlrFunc, transHdrl) + svr, err := NewServer(opt, transHdrl) test.Assert(t, err == nil, err) err = <-svr.Start() @@ -80,11 +76,8 @@ func TestServerStartListenErr(t *testing.T) { Address: utils.NewNetAddr("tcp", "test"), TransServerFactory: mocks.NewMockTransServerFactory(transSvr), } - inkHdlrFunc := func(ctx context.Context, req, resp interface{}) (err error) { - return nil - } transHdrl := &mocks.MockSvrTransHandler{} - svr, err := NewServer(opt, inkHdlrFunc, transHdrl) + svr, err := NewServer(opt, transHdrl) test.Assert(t, err == nil, err) err = <-svr.Start() diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index 5681365690..6a9d84ab42 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -1,5 +1,5 @@ /* - * Copyright 2021 CloudWeGo Authors + * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ import ( - OnActive - OnInactive - OnError -- GracefulShutdown: assert 方式使用 +- GracefulShutdown: used by type assert Other interface is used by trans pipeline */ diff --git a/pkg/streamx/client_options.go b/pkg/streamx/client_options.go index 9bded6a252..e93b077fd4 100644 --- a/pkg/streamx/client_options.go +++ b/pkg/streamx/client_options.go @@ -18,18 +18,9 @@ package streamx import ( "context" - "time" "github.com/cloudwego/kitex/pkg/stats" ) // EventHandler define stats event handler type EventHandler func(ctx context.Context, evt stats.Event, err error) - -// ClientOptions define the client options -type ClientOptions struct { - RecvTimeout time.Duration - StreamMWs []StreamMiddleware - StreamRecvMWs []StreamRecvMiddleware - StreamSendMWs []StreamSendMiddleware -} diff --git a/server/server.go b/server/server.go index d056299c07..cb68710aa6 100644 --- a/server/server.go +++ b/server/server.go @@ -256,7 +256,7 @@ func (s *server) Run() (err error) { return err } s.Lock() - s.svr, err = remotesvr.NewServer(s.opt.RemoteOpt, s.eps, transHdlr) + s.svr, err = remotesvr.NewServer(s.opt.RemoteOpt, transHdlr) s.Unlock() if err != nil { return err From 1431318f9e71b3cdb9e0758644cfe027d3cda071 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Mon, 11 Nov 2024 16:52:46 +0800 Subject: [PATCH 18/34] fix: mux transport leak --- internal/client/option.go | 2 +- .../ttstream/client_trans_pool_muxconn.go | 133 +++++++++--------- .../provider/ttstream/container/pipe.go | 57 ++++---- .../provider/ttstream/container/pipe_test.go | 39 +++++ pkg/streamx/provider/ttstream/transport.go | 6 +- pkg/streamx/streamx_user_test.go | 3 +- pkg/transmeta/ttheader.go | 2 +- 7 files changed, 138 insertions(+), 104 deletions(-) diff --git a/internal/client/option.go b/internal/client/option.go index 86efeac653..b34dfecb45 100644 --- a/internal/client/option.go +++ b/internal/client/option.go @@ -21,7 +21,6 @@ import ( "context" "time" - "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/localsession/backup" "github.com/cloudwego/kitex/internal/configutil" @@ -48,6 +47,7 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" + "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/transmeta" "github.com/cloudwego/kitex/pkg/utils" "github.com/cloudwego/kitex/pkg/warmup" diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go index ef44e3a07b..e560940ad5 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go @@ -18,14 +18,12 @@ package ttstream import ( "errors" - "fmt" "runtime" "sync" "sync/atomic" "time" "github.com/cloudwego/netpoll" - "golang.org/x/sync/singleflight" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" @@ -49,17 +47,19 @@ type muxConnTransList struct { cursor uint32 transports []*transport pool transPool - sf singleflight.Group } func newMuxConnTransList(size int, pool transPool) *muxConnTransList { tl := new(muxConnTransList) - if size == 0 { + if size <= 0 { size = runtime.GOMAXPROCS(0) } tl.size = size tl.transports = make([]*transport, size) tl.pool = pool + runtime.SetFinalizer(tl, func(tl *muxConnTransList) { + tl.Close() + }) return tl } @@ -76,38 +76,44 @@ func (tl *muxConnTransList) Close() { } func (tl *muxConnTransList) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (*transport, error) { + // fast path idx := atomic.AddUint32(&tl.cursor, 1) % uint32(tl.size) tl.L.RLock() trans := tl.transports[idx] tl.L.RUnlock() - if trans != nil && trans.IsActive() { - return trans, nil + if trans != nil { + if trans.IsActive() { + return trans, nil + } + _ = trans.Close(nil) } - v, err, _ := tl.sf.Do(fmt.Sprintf("%d", idx), func() (interface{}, error) { - conn, err := dialer.DialConnection(network, addr, time.Second) - if err != nil { - return nil, err - } - trans := newTransport(clientTransport, sinfo, conn, tl.pool) - _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { - // peer close - _ = trans.Close(terrors.ErrTransport.WithCause(errors.New("connection closed by peer"))) - return nil - }) - runtime.SetFinalizer(trans, func(trans *transport) { - // self close when not hold by user - _ = trans.Close(nil) - }) - tl.L.Lock() - tl.transports[idx] = trans + // slow path + tl.L.Lock() + trans = tl.transports[idx] + if trans != nil && trans.IsActive() { + // another goroutine already create the new transport tl.L.Unlock() return trans, nil - }) + } + // it may create more than tl.size transport if multi client try to get transport concurrently + conn, err := dialer.DialConnection(network, addr, time.Second) if err != nil { return nil, err } - trans = v.(*transport) + trans = newTransport(clientTransport, sinfo, conn, tl.pool) + _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { + // peer close + _ = trans.Close(terrors.ErrTransport.WithCause(errors.New("connection closed by peer"))) + return nil + }) + runtime.SetFinalizer(trans, func(trans *transport) { + // self close when not hold by user + _ = trans.Close(nil) + }) + tl.transports[idx] = trans + tl.L.Unlock() + return trans, nil } @@ -121,56 +127,51 @@ type muxConnTransPool struct { config MuxConnConfig pool sync.Map // addr:*muxConnTransList activity sync.Map // addr:lastActive - sflight singleflight.Group - cleanerOnce sync.Once + cleanerOnce int32 } func (p *muxConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (trans *transport, err error) { v, ok := p.pool.Load(addr) - if ok { - return v.(*muxConnTransList).Get(sinfo, network, addr) - } - - v, err, _ = p.sflight.Do(addr, func() (interface{}, error) { - transList := newMuxConnTransList(p.config.PoolSize, p) - p.pool.Store(addr, transList) - return transList, nil - }) - if err != nil { - return nil, err + if !ok { + // multi concurrent Get should get the same TransList object + v, _ = p.pool.LoadOrStore(addr, newMuxConnTransList(p.config.PoolSize, p)) } return v.(*muxConnTransList).Get(sinfo, network, addr) } func (p *muxConnTransPool) Put(trans *transport) { p.activity.Store(trans.conn.RemoteAddr().String(), time.Now()) - p.cleanerOnce.Do(func() { - internal := p.config.MaxIdleTimeout - if internal == 0 { - return - } - go func() { - for { - now := time.Now() - count := 0 - p.activity.Range(func(addr, value interface{}) bool { - count++ - lastActive := value.(time.Time) - if lastActive.IsZero() || now.Sub(lastActive) < p.config.MaxIdleTimeout { - return true - } - v, _ := p.pool.Load(addr) - if v == nil { - return true - } - transList := v.(*muxConnTransList) - p.pool.Delete(addr) - p.activity.Delete(addr) - transList.Close() + + idleTimeout := p.config.MaxIdleTimeout + if idleTimeout == 0 { + return + } + if !atomic.CompareAndSwapInt32(&p.cleanerOnce, 0, 1) { + return + } + // start cleaning background goroutine + go func() { + for { + now := time.Now() + count := 0 + p.activity.Range(func(key, value interface{}) bool { + addr := key.(string) + count++ + lastActive := value.(time.Time) + idleTime := now.Sub(lastActive) + if lastActive.IsZero() || idleTime < 0 || now.Sub(lastActive) < idleTimeout { return true - }) - time.Sleep(internal) - } - }() - }) + } + // clean transport + v, _ := p.pool.Load(addr) + if v == nil { + return true + } + p.pool.Delete(addr) + p.activity.Delete(addr) + return true + }) + time.Sleep(idleTimeout) + } + }() } diff --git a/pkg/streamx/provider/ttstream/container/pipe.go b/pkg/streamx/provider/ttstream/container/pipe.go index a2c8254354..7b8adaf875 100644 --- a/pkg/streamx/provider/ttstream/container/pipe.go +++ b/pkg/streamx/provider/ttstream/container/pipe.go @@ -26,10 +26,9 @@ import ( type pipeState = int32 const ( - pipeStateInactive pipeState = 0 - pipeStateActive pipeState = 1 - pipeStateClosed pipeState = 2 - pipeStateCanceled pipeState = 3 + pipeStateActive pipeState = 0 + pipeStateClosed pipeState = 1 + pipeStateCanceled pipeState = 2 ) var ( @@ -44,20 +43,22 @@ var ( // Pipe implement a queue that never block on Write but block on Read if there is nothing to read type Pipe[Item any] struct { queue *Queue[Item] - trigger chan struct{} state pipeState + trigger chan struct{} } func NewPipe[Item any]() *Pipe[Item] { p := new(Pipe[Item]) p.queue = NewQueue[Item]() p.trigger = make(chan struct{}, 1) + p.state = pipeStateActive return p } // Read will block if there is nothing to read func (p *Pipe[Item]) Read(ctx context.Context, items []Item) (n int, err error) { READ: + // check readable items for i := 0; i < len(items); i++ { val, ok := p.queue.Get() if !ok { @@ -70,6 +71,12 @@ READ: return n, nil } + // check state + state := atomic.LoadInt32(&p.state) + if err = stateErrors[state]; err != nil { + return 0, err + } + // no data to read, waiting writes for { if ctx.Done() != nil { @@ -81,27 +88,11 @@ READ: } else { <-p.trigger } - - if p.queue.Size() == 0 { - err = stateErrors[atomic.LoadInt32(&p.state)] - if err != nil { - return 0, err - } - } goto READ } } func (p *Pipe[Item]) Write(ctx context.Context, items ...Item) (err error) { - if !atomic.CompareAndSwapInt32(&p.state, pipeStateInactive, pipeStateActive) && atomic.LoadInt32(&p.state) != pipeStateActive { - err = stateErrors[atomic.LoadInt32(&p.state)] - if err != nil { - return err - } - // never happen error - return fmt.Errorf("unknown state error") - } - for _, item := range items { p.queue.Add(item) } @@ -114,21 +105,21 @@ func (p *Pipe[Item]) Write(ctx context.Context, items ...Item) (err error) { } func (p *Pipe[Item]) Close() { - if atomic.LoadInt32(&p.state) != pipeStateClosed { - atomic.StoreInt32(&p.state, pipeStateClosed) - select { - case p.trigger <- struct{}{}: - default: - } + if !atomic.CompareAndSwapInt32(&p.state, pipeStateActive, pipeStateClosed) { + return + } + select { + case p.trigger <- struct{}{}: + default: } } func (p *Pipe[Item]) Cancel() { - if atomic.LoadInt32(&p.state) != pipeStateCanceled { - atomic.StoreInt32(&p.state, pipeStateCanceled) - select { - case p.trigger <- struct{}{}: - default: - } + if !atomic.CompareAndSwapInt32(&p.state, pipeStateActive, pipeStateCanceled) { + return + } + select { + case p.trigger <- struct{}{}: + default: } } diff --git a/pkg/streamx/provider/ttstream/container/pipe_test.go b/pkg/streamx/provider/ttstream/container/pipe_test.go index 7475fd13d8..cfc3f4f3f4 100644 --- a/pkg/streamx/provider/ttstream/container/pipe_test.go +++ b/pkg/streamx/provider/ttstream/container/pipe_test.go @@ -18,9 +18,13 @@ package container import ( "context" + "errors" "io" + "runtime" "sync" + "sync/atomic" "testing" + "time" ) func TestPipeline(t *testing.T) { @@ -62,6 +66,41 @@ func TestPipeline(t *testing.T) { } } +func TestPipelineWriteCloseAndRead(t *testing.T) { + ctx := context.Background() + pipe := NewPipe[int]() + pipeSize := 10 + var pipeRead int32 + var readWG sync.WaitGroup + readWG.Add(1) + go func() { + defer readWG.Done() + readBuf := make([]int, 1) + for { + n, err := pipe.Read(ctx, readBuf) + if err != nil { + if !errors.Is(err, ErrPipeEOF) { + t.Errorf("un except err: %v", err) + } + break + } + atomic.AddInt32(&pipeRead, int32(n)) + } + }() + time.Sleep(time.Millisecond * 10) // let read goroutine start first + for i := 0; i < pipeSize; i++ { + err := pipe.Write(ctx, i) + if err != nil { + t.Error(err) + } + } + for atomic.LoadInt32(&pipeRead) != int32(pipeSize) { + runtime.Gosched() + } + pipe.Close() + readWG.Wait() +} + func BenchmarkPipeline(b *testing.B) { ctx := context.Background() pipe := NewPipe[int]() diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index c329629db4..a237fae5ac 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -126,14 +126,16 @@ func (t *transport) Close(exception error) (err error) { return nil } klog.Debugf("transport[%d-%s] is closing", t.kind, t.Addr()) - t.spipe.Close() - t.fpipe.Close() + // send trailer first t.streams.Range(func(key, value any) bool { s := value.(*stream) _ = s.closeSend(exception) _ = s.closeRecv(exception) return true }) + // then close stream and frame pipes + t.spipe.Close() + t.fpipe.Close() return err } diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index d9bf8a68d6..3f446e07cb 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -61,7 +61,7 @@ func init() { providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_LongConn", ClientProvider: cp, ServerProvider: sp}) cp, _ = ttstream.NewClientProvider(testServiceInfo, ttstream.WithClientShortConnPool()) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_ShortConn", ClientProvider: cp, ServerProvider: sp}) - cp, _ = ttstream.NewClientProvider(testServiceInfo, ttstream.WithClientMuxConnPool(ttstream.MuxConnConfig{PoolSize: 8, MaxIdleTimeout: time.Millisecond * 1000})) + cp, _ = ttstream.NewClientProvider(testServiceInfo, ttstream.WithClientMuxConnPool(ttstream.MuxConnConfig{PoolSize: 8, MaxIdleTimeout: time.Millisecond * 100})) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_Mux", ClientProvider: cp, ServerProvider: sp}) } @@ -720,6 +720,7 @@ func TestStreamingGoroutineLeak(t *testing.T) { for i := 0; i < streams; i++ { streamList[i] = nil } + streamList = nil for runtime.NumGoroutine() > ngBefore { t.Logf("ngCurrent=%d > ngBefore=%d", runtime.NumGoroutine(), ngBefore) runtime.GC() diff --git a/pkg/transmeta/ttheader.go b/pkg/transmeta/ttheader.go index 416c477e13..3b18537a58 100644 --- a/pkg/transmeta/ttheader.go +++ b/pkg/transmeta/ttheader.go @@ -126,7 +126,7 @@ func (ch *clientTTHeaderHandler) ReadMeta(ctx context.Context, msg remote.Messag } func ParseBizStatusErr(strInfo map[string]string) (kerrors.BizStatusErrorIface, error) { - if code, err := strconv.Atoi(strInfo[bizStatus]); err == nil && code != 0 { + if code, err := strconv.ParseInt(strInfo[bizStatus], 10, 32); err == nil && code != 0 { if bizExtra := strInfo[bizExtra]; bizExtra != "" { extra, err := utils.JSONStr2Map(bizExtra) if err != nil { From 20592bdd4ca43b0b986059b35cb69b80765eae1f Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Mon, 11 Nov 2024 17:10:10 +0800 Subject: [PATCH 19/34] chore: fix stream recv timeout test --- pkg/streamx/streamx_user_test.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index 3f446e07cb..028b11d1a0 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -552,10 +552,6 @@ func TestStreamingBasic(t *testing.T) { t.Logf("=== Timeout by Ctx ===") ctx, bs, err := cli.BidiStream(octx) test.Assert(t, err == nil, err) - req = new(Request) - req.Message = string(make([]byte, 1024)) - err = bs.Send(ctx, req) - test.Assert(t, err == nil, err) nctx, cancel := context.WithCancel(ctx) cancel() _, err = bs.Recv(nctx) @@ -574,10 +570,6 @@ func TestStreamingBasic(t *testing.T) { ) ctx, bs, err = cli.BidiStream(octx) test.Assert(t, err == nil, err) - req = new(Request) - req.Message = string(make([]byte, 1024)) - err = bs.Send(ctx, req) - test.Assert(t, err == nil, err) _, err = bs.Recv(ctx) test.Assert(t, err != nil, err) t.Logf("recv timeout error: %v", err) @@ -720,7 +712,6 @@ func TestStreamingGoroutineLeak(t *testing.T) { for i := 0; i < streams; i++ { streamList[i] = nil } - streamList = nil for runtime.NumGoroutine() > ngBefore { t.Logf("ngCurrent=%d > ngBefore=%d", runtime.NumGoroutine(), ngBefore) runtime.GC() From c742bab24b509941f7e4ba71ee5e0e50551ab2e4 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Mon, 11 Nov 2024 17:49:02 +0800 Subject: [PATCH 20/34] chore: reduce streamx log --- .github/workflows/tests.yml | 2 +- pkg/streamx/streamx_user_service_test.go | 20 ++++++++++---------- pkg/streamx/streamx_user_test.go | 2 -- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e71e772198..90b0c4a907 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,7 @@ jobs: go-version: ${{ matrix.go }} cache: false # don't use cache for self-hosted runners - name: Unit Test - run: go test -v -race ./... + run: go test -race ./... codegen-test: runs-on: ubuntu-latest diff --git a/pkg/streamx/streamx_user_service_test.go b/pkg/streamx/streamx_user_service_test.go index c48dc5f212..59bde4e5f2 100644 --- a/pkg/streamx/streamx_user_service_test.go +++ b/pkg/streamx/streamx_user_service_test.go @@ -62,13 +62,13 @@ func (si *testService) setHeaderAndTrailer(stream streamx.ServerStreamMetadata) func (si *testService) PingPong(ctx context.Context, req *Request) (*Response, error) { resp := &Response{Type: req.Type, Message: req.Message} - klog.Infof("Server PingPong: req={%v} resp={%v}", req, resp) + klog.Debugf("Server PingPong: req={%v} resp={%v}", req, resp) return resp, nil } func (si *testService) Unary(ctx context.Context, req *Request) (*Response, error) { resp := &Response{Type: req.Type, Message: req.Message} - klog.Infof("Server Unary: req={%v} resp={%v}", req, resp) + klog.Debugf("Server Unary: req={%v} resp={%v}", req, resp) return resp, nil } @@ -76,8 +76,8 @@ func (si *testService) ClientStream(ctx context.Context, stream streamx.ClientStreamingServer[Request, Response], ) (*Response, error) { var msg string - klog.Infof("Server ClientStream start") - defer klog.Infof("Server ClientStream end") + klog.Debugf("Server ClientStream start") + defer klog.Debugf("Server ClientStream end") if err := si.setHeaderAndTrailer(stream); err != nil { return nil, err @@ -100,7 +100,7 @@ func (si *testService) ClientStream(ctx context.Context, func (si *testService) ServerStream(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response], ) error { - klog.Infof("Server ServerStream: req={%v}", req) + klog.Debugf("Server ServerStream: req={%v}", req) if err := si.setHeaderAndTrailer(stream); err != nil { return err @@ -114,7 +114,7 @@ func (si *testService) ServerStream(ctx context.Context, req *Request, if err != nil { return err } - klog.Infof("Server ServerStream: send resp={%v}", resp) + klog.Debugf("Server ServerStream: send resp={%v}", resp) } return nil } @@ -165,7 +165,7 @@ func buildErr(req *Request) error { func (si *testService) UnaryWithErr(ctx context.Context, req *Request) (*Response, error) { err := buildErr(req) - klog.Infof("Server UnaryWithErr: req={%v} err={%v}", req, err) + klog.Debugf("Server UnaryWithErr: req={%v} err={%v}", req, err) return nil, err } @@ -176,13 +176,13 @@ func (si *testService) ClientStreamWithErr(ctx context.Context, stream streamx.C return nil, err } err = buildErr(req) - klog.Infof("Server ClientStreamWithErr: req={%v} err={%v}", req, err) + klog.Debugf("Server ClientStreamWithErr: req={%v} err={%v}", req, err) return nil, err } func (si *testService) ServerStreamWithErr(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error { err := buildErr(req) - klog.Infof("Server ServerStreamWithErr: req={%v} err={%v}", req, err) + klog.Debugf("Server ServerStreamWithErr: req={%v} err={%v}", req, err) return err } @@ -193,6 +193,6 @@ func (si *testService) BidiStreamWithErr(ctx context.Context, stream streamx.Bid return err } err = buildErr(req) - klog.Infof("Server BidiStreamWithErr: req={%v} err={%v}", req, err) + klog.Debugf("Server BidiStreamWithErr: req={%v} err={%v}", req, err) return err } diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index 028b11d1a0..6056e7d065 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -37,7 +37,6 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/transport" - "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" "github.com/cloudwego/kitex/client/streamxclient" @@ -69,7 +68,6 @@ func TestMain(m *testing.M) { go func() { log.Println(http.ListenAndServe("localhost:6060", nil)) }() - klog.SetLevel(klog.LevelDebug) m.Run() } From 4fcf8b48baceef5d17bad293de468e2c9255f8c0 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Tue, 12 Nov 2024 14:15:48 +0800 Subject: [PATCH 21/34] fix: start goroutine with recover --- .../ttstream/client_trans_pool_muxconn.go | 6 ++- .../ttstream/container/object_pool.go | 5 +- pkg/streamx/provider/ttstream/transport.go | 51 ++++++++++--------- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go index e560940ad5..5d75356ef2 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go @@ -17,6 +17,7 @@ package ttstream import ( + "context" "errors" "runtime" "sync" @@ -25,6 +26,7 @@ import ( "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) @@ -150,7 +152,7 @@ func (p *muxConnTransPool) Put(trans *transport) { return } // start cleaning background goroutine - go func() { + gofunc.RecoverGoFuncWithInfo(context.Background(), func() { for { now := time.Now() count := 0 @@ -173,5 +175,5 @@ func (p *muxConnTransPool) Put(trans *transport) { }) time.Sleep(idleTimeout) } - }() + }, gofunc.NewBasicInfo("", trans.Addr().String())) } diff --git a/pkg/streamx/provider/ttstream/container/object_pool.go b/pkg/streamx/provider/ttstream/container/object_pool.go index 6ac599cbdb..bf5ac770d9 100644 --- a/pkg/streamx/provider/ttstream/container/object_pool.go +++ b/pkg/streamx/provider/ttstream/container/object_pool.go @@ -17,9 +17,12 @@ package container import ( + "context" "sync" "sync/atomic" "time" + + "github.com/cloudwego/kitex/pkg/gofunc" ) type Object interface { @@ -35,7 +38,7 @@ func NewObjectPool(idleTimeout time.Duration) *ObjectPool { s := new(ObjectPool) s.idleTimeout = idleTimeout s.objects = make(map[string]*Stack[objectItem]) - go s.cleaning() + gofunc.RecoverGoFuncWithInfo(context.Background(), s.cleaning, gofunc.NewBasicInfo("", "")) return s } diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index a237fae5ac..e3dab54cb8 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -24,11 +24,11 @@ import ( "net" "sync" "sync/atomic" - "time" "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" @@ -63,9 +63,8 @@ type transport struct { } func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Connection, pool transPool) *transport { - // stream max idle session is 10 minutes. // TODO: let it configurable - _ = conn.SetReadTimeout(time.Minute * 10) + _ = conn.SetReadTimeout(0) t := &transport{ kind: kind, sinfo: sinfo, @@ -77,32 +76,38 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne fpipe: container.NewPipe[*Frame](), closedTrigger: make(chan struct{}, 2), } - go func() { + addr := "" + if t.Addr() != nil { + addr = t.Addr().String() + } + gofunc.RecoverGoFuncWithInfo(context.Background(), func() { + var err error defer func() { + if err != nil { + if !isIgnoreError(err) { + klog.Warnf("transport[%d-%s] loop read err: %v", t.kind, t.Addr(), err) + } + // if connection is closed by peer, loop read should return ErrConnClosed error, + // so we should close transport here + _ = t.Close(err) + } t.closedTrigger <- struct{}{} }() - err := t.loopRead() - if err != nil { - if !isIgnoreError(err) { - klog.Warnf("transport[%d-%s] loop read err: %v", t.kind, t.Addr(), err) - } - // if connection is closed by peer, loop read should return ErrConnClosed error, - // so we should close transport here - _ = t.Close(err) - } - }() - go func() { + err = t.loopRead() + }, gofunc.NewBasicInfo(sinfo.ServiceName, addr)) + gofunc.RecoverGoFuncWithInfo(context.Background(), func() { + var err error defer func() { + if err != nil { + if !isIgnoreError(err) { + klog.Warnf("transport[%d-%s] loop write err: %v", t.kind, t.Addr(), err) + } + _ = t.Close(err) + } t.closedTrigger <- struct{}{} }() - err := t.loopWrite() - if err != nil { - if !isIgnoreError(err) { - klog.Warnf("transport[%d-%s] loop write err: %v", t.kind, t.Addr(), err) - } - _ = t.Close(err) - } - }() + err = t.loopWrite() + }, gofunc.NewBasicInfo(sinfo.ServiceName, addr)) return t } From 29790d15915cc088837f6d8181f453948b6b8664 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Tue, 12 Nov 2024 14:40:56 +0800 Subject: [PATCH 22/34] chore: ut wait server stream done --- pkg/streamx/streamx_user_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index 6056e7d065..7a2c4e17ee 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -127,6 +127,11 @@ func TestStreamingBasic(t *testing.T) { atomic.StoreInt32(&serverRecvCount, 0) atomic.StoreInt32(&serverSendCount, 0) } + waitServerStreamDone := func(streamCount int) { + for atomic.LoadInt32(&serverStreamCount) < int32(streamCount) { + runtime.Gosched() + } + } addr, svr, err := NewTestServer( new(testService), server.WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { @@ -324,6 +329,7 @@ func TestStreamingBasic(t *testing.T) { test.Assert(t, req.Message == res.Message, res.Message) }() } + waitServerStreamDone(concurrency) wg.Wait() test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) test.DeepEqual(t, atomic.LoadInt32(&clientMiddlewareCount), int32(concurrency)) @@ -357,6 +363,7 @@ func TestStreamingBasic(t *testing.T) { testHeaderAndTrailer(t, cs) }() } + waitServerStreamDone(concurrency) wg.Wait() untilEqual(t, &serverStreamCount, int32(concurrency), time.Second) test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) @@ -393,6 +400,7 @@ func TestStreamingBasic(t *testing.T) { testHeaderAndTrailer(t, ss) }() } + waitServerStreamDone(concurrency) wg.Wait() untilEqual(t, &serverStreamCount, int32(concurrency), time.Second) test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) @@ -443,6 +451,7 @@ func TestStreamingBasic(t *testing.T) { testHeaderAndTrailer(t, bs) }() } + waitServerStreamDone(concurrency) wg.Wait() untilEqual(t, &serverStreamCount, int32(concurrency), time.Second) untilEqual(t, &clientStreamCount, int32(concurrency), time.Second) From 509060e538e2b56d283774ba6e76a2a979fb51b4 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Wed, 13 Nov 2024 13:10:20 +0800 Subject: [PATCH 23/34] fix: client StreamX API ignore req args --- client/client_streamx.go | 10 +++++----- client/streamxclient/client_gen.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/client/client_streamx.go b/client/client_streamx.go index ac1369a847..b865f70c4f 100644 --- a/client/client_streamx.go +++ b/client/client_streamx.go @@ -25,11 +25,11 @@ import ( ) type StreamX interface { - NewStream(ctx context.Context, method string, req any, streamArgs streamx.StreamArgs, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) + NewStream(ctx context.Context, method string, streamArgs streamx.StreamArgs, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) } // NewStream create stream for streamx mode -func (kc *kClient) NewStream(ctx context.Context, method string, req any, streamArgs streamx.StreamArgs, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) { +func (kc *kClient) NewStream(ctx context.Context, method string, streamArgs streamx.StreamArgs, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) { if !kc.inited { panic("client not initialized") } @@ -61,10 +61,10 @@ func (kc *kClient) NewStream(ctx context.Context, method string, req any, stream msargs.SetStreamRecvMiddleware(kc.sxStreamRecvMW) msargs.SetStreamSendMiddleware(kc.sxStreamSendMW) } - // put streamArgs into response arg + // with streamx mode, req is nil and resp is streamArgs // it's an ugly trick but if we don't want to refactor too much, - // this is the only way to compatible with current endpoint design - err = kc.sEps(ctx, req, streamArgs) + // this is the only way to compatible with current endpoint API. + err = kc.sEps(ctx, nil, streamArgs) if err != nil { return nil, nil, err } diff --git a/client/streamxclient/client_gen.go b/client/streamxclient/client_gen.go index fedc42b12d..781e4e2325 100644 --- a/client/streamxclient/client_gen.go +++ b/client/streamxclient/client_gen.go @@ -42,7 +42,7 @@ func InvokeStream[Req, Res any]( } // NewStream should register client middlewares into stream Args - ctx, cs, err := cli.NewStream(ctx, method, req, streamArgs, callOptions...) + ctx, cs, err := cli.NewStream(ctx, method, streamArgs, callOptions...) if err != nil { return nil, nil, err } From 547778596f5af6d1738de20ec4bba6a17b991dde Mon Sep 17 00:00:00 2001 From: Joway Date: Fri, 15 Nov 2024 14:06:38 +0800 Subject: [PATCH 24/34] feat: refactor streamx APIs (#1612) --- client/client.go | 3 ++ client/client_streamx.go | 4 --- client/streamxclient/client_gen.go | 10 ++++-- client/streamxclient/client_option.go | 6 ---- internal/streamx/stream.go | 27 +++++++++++++++ internal/streamx/streamxclient/client.go | 39 ++++++++++++++++++++++ internal/streamx/streamxserver/server.go | 32 ++++++++++++++++++ pkg/remote/trans/streamx/server_handler.go | 12 ++++--- pkg/rpcinfo/interface.go | 1 + pkg/rpcinfo/invocation.go | 25 ++++++++++---- pkg/streamx/stream.go | 3 -- pkg/streamx/streamx_gen_service_test.go | 21 ++++++------ pkg/streamx/streamx_user_test.go | 31 ++++++++++------- server/streamxserver/option.go | 6 ---- 14 files changed, 166 insertions(+), 54 deletions(-) create mode 100644 internal/streamx/stream.go create mode 100644 internal/streamx/streamxclient/client.go create mode 100644 internal/streamx/streamxserver/server.go diff --git a/client/client.go b/client/client.go index c2fdf25eb4..eeb7ec8a20 100644 --- a/client/client.go +++ b/client/client.go @@ -737,6 +737,9 @@ func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInf rpcStats.ImmutableView(), ) + if mi != nil { + ri.Invocation().(rpcinfo.InvocationSetter).SetStreamingMode(mi.StreamingMode()) + } if fromMethod := ctx.Value(consts.CtxKeyMethod); fromMethod != nil { rpcinfo.AsMutableEndpointInfo(ri.From()).SetMethod(fromMethod.(string)) } diff --git a/client/client_streamx.go b/client/client_streamx.go index b865f70c4f..16093c1980 100644 --- a/client/client_streamx.go +++ b/client/client_streamx.go @@ -24,10 +24,6 @@ import ( "github.com/cloudwego/kitex/pkg/streamx" ) -type StreamX interface { - NewStream(ctx context.Context, method string, streamArgs streamx.StreamArgs, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) -} - // NewStream create stream for streamx mode func (kc *kClient) NewStream(ctx context.Context, method string, streamArgs streamx.StreamArgs, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) { if !kc.inited { diff --git a/client/streamxclient/client_gen.go b/client/streamxclient/client_gen.go index 781e4e2325..b237aacedf 100644 --- a/client/streamxclient/client_gen.go +++ b/client/streamxclient/client_gen.go @@ -18,16 +18,18 @@ package streamxclient import ( "context" + "errors" "github.com/cloudwego/kitex/client" "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/internal/streamx/streamxclient" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" ) // InvokeStream create a new client stream and wrapped related middlewares func InvokeStream[Req, Res any]( - ctx context.Context, cli client.StreamX, smode serviceinfo.StreamingMode, method string, + ctx context.Context, cli client.Client, smode serviceinfo.StreamingMode, method string, req *Req, res *Res, callOptions ...streamxcallopt.CallOption, ) (context.Context, *streamx.GenericClientStream[Req, Res], error) { reqArgs, resArgs := streamx.NewStreamReqArgs(nil), streamx.NewStreamResArgs(nil) @@ -41,8 +43,12 @@ func InvokeStream[Req, Res any]( resArgs.SetRes(res) } + scli, ok := cli.(streamxclient.StreamXClient) + if !ok { + return nil, nil, errors.New("current client is not support streamx interface") + } // NewStream should register client middlewares into stream Args - ctx, cs, err := cli.NewStream(ctx, method, streamArgs, callOptions...) + ctx, cs, err := scli.NewStream(ctx, method, streamArgs, callOptions...) if err != nil { return nil, nil, err } diff --git a/client/streamxclient/client_option.go b/client/streamxclient/client_option.go index 7005622c43..9c55c3fdcb 100644 --- a/client/streamxclient/client_option.go +++ b/client/streamxclient/client_option.go @@ -36,12 +36,6 @@ func WithStreamRecvTimeout(timeout time.Duration) internal_client.Option { }} } -func WithStreamMiddleware(smw streamx.StreamMiddleware) internal_client.Option { - return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { - o.StreamX.StreamMWs = append(o.StreamX.StreamMWs, smw) - }} -} - func WithStreamRecvMiddleware(smw streamx.StreamRecvMiddleware) internal_client.Option { return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { o.StreamX.StreamRecvMWs = append(o.StreamX.StreamRecvMWs, smw) diff --git a/internal/streamx/stream.go b/internal/streamx/stream.go new file mode 100644 index 0000000000..af94d9d424 --- /dev/null +++ b/internal/streamx/stream.go @@ -0,0 +1,27 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package streamx + +import ( + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +type StreamInfo interface { + Service() string + Method() string + Mode() serviceinfo.StreamingMode +} diff --git a/internal/streamx/streamxclient/client.go b/internal/streamx/streamxclient/client.go new file mode 100644 index 0000000000..3bd833bf6d --- /dev/null +++ b/internal/streamx/streamxclient/client.go @@ -0,0 +1,39 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package streamxclient + +import ( + "context" + + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + internal_client "github.com/cloudwego/kitex/internal/client" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/utils" +) + +// StreamXClient implement a streamx interface client +type StreamXClient interface { + NewStream(ctx context.Context, method string, streamArgs streamx.StreamArgs, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStream, error) +} + +// WithStreamMiddleware currently is not open for users to use for now +// so it's include in internal package +func WithStreamMiddleware(smw streamx.StreamMiddleware) internal_client.Option { + return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { + o.StreamX.StreamMWs = append(o.StreamX.StreamMWs, smw) + }} +} diff --git a/internal/streamx/streamxserver/server.go b/internal/streamx/streamxserver/server.go new file mode 100644 index 0000000000..1683e0ba93 --- /dev/null +++ b/internal/streamx/streamxserver/server.go @@ -0,0 +1,32 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package streamxserver + +import ( + internal_server "github.com/cloudwego/kitex/internal/server" + "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/utils" + "github.com/cloudwego/kitex/server" +) + +// WithStreamMiddleware currently is not open for users to use for now +// so it's include in internal package +func WithStreamMiddleware(mw streamx.StreamMiddleware) server.Option { + return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { + o.StreamX.StreamMiddlewares = append(o.StreamX.StreamMiddlewares, mw) + }} +} diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index 6a9d84ab42..92fe2193e1 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -25,6 +25,7 @@ import ( "sync" "time" + istreamx "github.com/cloudwego/kitex/internal/streamx" "github.com/cloudwego/kitex/internal/wpool" "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/klog" @@ -153,10 +154,13 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss stream }() ink := ri.Invocation().(rpcinfo.InvocationSetter) - ink.SetServiceName(ss.Service()) - ink.SetMethodName(ss.Method()) - if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil { - _ = mutableTo.SetMethod(ss.Method()) + if si, ok := ss.(istreamx.StreamInfo); ok { + ink.SetServiceName(si.Service()) + ink.SetMethodName(si.Method()) + ink.SetStreamingMode(si.Mode()) + if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil { + _ = mutableTo.SetMethod(si.Method()) + } } ctx = t.startTracer(ctx, ri) diff --git a/pkg/rpcinfo/interface.go b/pkg/rpcinfo/interface.go index c8249e8605..2f5eb4a765 100644 --- a/pkg/rpcinfo/interface.go +++ b/pkg/rpcinfo/interface.go @@ -92,6 +92,7 @@ type Invocation interface { PackageName() string ServiceName() string MethodName() string + StreamingMode() serviceinfo.StreamingMode SeqID() int32 BizStatusErr() kerrors.BizStatusErrorIface Extra(key string) interface{} diff --git a/pkg/rpcinfo/invocation.go b/pkg/rpcinfo/invocation.go index cf6dc0165f..4033547e23 100644 --- a/pkg/rpcinfo/invocation.go +++ b/pkg/rpcinfo/invocation.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) var ( @@ -39,18 +40,20 @@ type InvocationSetter interface { SetPackageName(name string) SetServiceName(name string) SetMethodName(name string) + SetStreamingMode(mode serviceinfo.StreamingMode) SetSeqID(seqID int32) SetBizStatusErr(err kerrors.BizStatusErrorIface) SetExtra(key string, value interface{}) Reset() } type invocation struct { - packageName string - serviceName string - methodName string - seqID int32 - bizErr kerrors.BizStatusErrorIface - extra map[string]interface{} + packageName string + serviceName string + methodName string + streamingMode serviceinfo.StreamingMode + seqID int32 + bizErr kerrors.BizStatusErrorIface + extra map[string]interface{} } // NewInvocation creates a new Invocation with the given service, method and optional package. @@ -121,6 +124,16 @@ func (i *invocation) SetMethodName(name string) { i.methodName = name } +// StreamingMode implements the Invocation interface. +func (i *invocation) StreamingMode() serviceinfo.StreamingMode { + return i.streamingMode +} + +// SetStreamingMode implements the InvocationSetter interface. +func (i *invocation) SetStreamingMode(mode serviceinfo.StreamingMode) { + i.streamingMode = mode +} + // BizStatusErr implements the Invocation interface. func (i *invocation) BizStatusErr() kerrors.BizStatusErrorIface { return i.bizErr diff --git a/pkg/streamx/stream.go b/pkg/streamx/stream.go index 5c30d399d1..e54df39b58 100644 --- a/pkg/streamx/stream.go +++ b/pkg/streamx/stream.go @@ -88,9 +88,6 @@ const ( // Stream define stream APIs type Stream interface { - Mode() StreamingMode - Service() string - Method() string SendMsg(ctx context.Context, m any) error RecvMsg(ctx context.Context, m any) error } diff --git a/pkg/streamx/streamx_gen_service_test.go b/pkg/streamx/streamx_gen_service_test.go index c267d6afd4..bf2828fd6d 100644 --- a/pkg/streamx/streamx_gen_service_test.go +++ b/pkg/streamx/streamx_gen_service_test.go @@ -195,7 +195,7 @@ func NewClient(destService string, opts ...client.Option) (TestServiceClient, er if err != nil { return nil, err } - kc := &kClient{caller: cli, streamer: cli.(client.StreamX)} + kc := &kClient{caller: cli} return kc, nil } @@ -237,8 +237,7 @@ type TestServiceClient interface { var _ TestServiceClient = (*kClient)(nil) type kClient struct { - caller client.Client - streamer client.StreamX + caller client.Client } func (c *kClient) PingPong(ctx context.Context, req *Request) (r *Response, err error) { @@ -254,7 +253,7 @@ func (c *kClient) PingPong(ctx context.Context, req *Request) (r *Response, err func (c *kClient) Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (*Response, error) { res := new(Response) _, _, err := streamxclient.InvokeStream[Request, Response]( - ctx, c.streamer, serviceinfo.StreamingUnary, "Unary", req, res, callOptions...) + ctx, c.caller, serviceinfo.StreamingUnary, "Unary", req, res, callOptions...) if err != nil { return nil, err } @@ -265,27 +264,27 @@ func (c *kClient) ClientStream(ctx context.Context, callOptions ...streamxcallop context.Context, streamx.ClientStreamingClient[Request, Response], error, ) { return streamxclient.InvokeStream[Request, Response]( - ctx, c.streamer, serviceinfo.StreamingClient, "ClientStream", nil, nil, callOptions...) + ctx, c.caller, serviceinfo.StreamingClient, "ClientStream", nil, nil, callOptions...) } func (c *kClient) ServerStream(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( context.Context, streamx.ServerStreamingClient[Response], error, ) { return streamxclient.InvokeStream[Request, Response]( - ctx, c.streamer, serviceinfo.StreamingServer, "ServerStream", req, nil, callOptions...) + ctx, c.caller, serviceinfo.StreamingServer, "ServerStream", req, nil, callOptions...) } func (c *kClient) BidiStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( context.Context, streamx.BidiStreamingClient[Request, Response], error, ) { return streamxclient.InvokeStream[Request, Response]( - ctx, c.streamer, serviceinfo.StreamingBidirectional, "BidiStream", nil, nil, callOptions...) + ctx, c.caller, serviceinfo.StreamingBidirectional, "BidiStream", nil, nil, callOptions...) } func (c *kClient) UnaryWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (*Response, error) { res := new(Response) _, _, err := streamxclient.InvokeStream[Request, Response]( - ctx, c.streamer, serviceinfo.StreamingUnary, "UnaryWithErr", req, res, callOptions...) + ctx, c.caller, serviceinfo.StreamingUnary, "UnaryWithErr", req, res, callOptions...) if err != nil { return nil, err } @@ -296,19 +295,19 @@ func (c *kClient) ClientStreamWithErr(ctx context.Context, callOptions ...stream context.Context, streamx.ClientStreamingClient[Request, Response], error, ) { return streamxclient.InvokeStream[Request, Response]( - ctx, c.streamer, serviceinfo.StreamingClient, "ClientStreamWithErr", nil, nil, callOptions...) + ctx, c.caller, serviceinfo.StreamingClient, "ClientStreamWithErr", nil, nil, callOptions...) } func (c *kClient) ServerStreamWithErr(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) ( context.Context, streamx.ServerStreamingClient[Response], error, ) { return streamxclient.InvokeStream[Request, Response]( - ctx, c.streamer, serviceinfo.StreamingServer, "ServerStreamWithErr", req, nil, callOptions...) + ctx, c.caller, serviceinfo.StreamingServer, "ServerStreamWithErr", req, nil, callOptions...) } func (c *kClient) BidiStreamWithErr(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( context.Context, streamx.BidiStreamingClient[Request, Response], error, ) { return streamxclient.InvokeStream[Request, Response]( - ctx, c.streamer, serviceinfo.StreamingBidirectional, "BidiStreamWithErr", nil, nil, callOptions...) + ctx, c.caller, serviceinfo.StreamingBidirectional, "BidiStreamWithErr", nil, nil, callOptions...) } diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index 7a2c4e17ee..0f87a9e6ca 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -33,17 +33,17 @@ import ( "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/client" - "github.com/cloudwego/kitex/pkg/endpoint" - "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/transport" - - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" - "github.com/cloudwego/kitex/client/streamxclient" + istreamxclient "github.com/cloudwego/kitex/internal/streamx/streamxclient" + istreamxserver "github.com/cloudwego/kitex/internal/streamx/streamxserver" "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/streamx" + "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream" "github.com/cloudwego/kitex/server" "github.com/cloudwego/kitex/server/streamxserver" + "github.com/cloudwego/kitex/transport" ) var providerTestCases []testCase @@ -157,16 +157,18 @@ func TestStreamingBasic(t *testing.T) { return err } }), - streamxserver.WithStreamMiddleware( + istreamxserver.WithStreamMiddleware( // middleware example: server streaming mode func(next streamx.StreamEndpoint) streamx.StreamEndpoint { return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + ri := rpcinfo.GetRPCInfo(ctx) + t.Logf("Server middleware before next: reqArgs=%v resArgs=%v streamArgs=%v", reqArgs.Req(), resArgs.Res(), streamArgs) test.Assert(t, streamArgs.Stream() != nil) test.Assert(t, validateMetadata(ctx)) - switch streamArgs.Stream().Mode() { + switch ri.Invocation().StreamingMode() { case streamx.StreamingUnary: test.Assert(t, reqArgs.Req() != nil) test.Assert(t, resArgs.Res() == nil) @@ -200,6 +202,8 @@ func TestStreamingBasic(t *testing.T) { err = next(ctx, streamArgs, reqArgs, resArgs) test.Assert(t, reqArgs.Req() == nil) test.Assert(t, resArgs.Res() == nil) + default: + t.Fatal("cannot get stream mode") } t.Logf("Server middleware after next: reqArgs=%v resArgs=%v streamArgs=%v err=%v", @@ -253,8 +257,9 @@ func TestStreamingBasic(t *testing.T) { return err } }), - streamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + istreamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { + ri := rpcinfo.GetRPCInfo(ctx) // validate ctx test.Assert(t, validateMetadata(ctx)) @@ -262,7 +267,7 @@ func TestStreamingBasic(t *testing.T) { test.Assert(t, streamArgs.Stream() != nil) if err == nil { - switch streamArgs.Stream().Mode() { + switch ri.Invocation().StreamingMode() { case streamx.StreamingUnary: test.Assert(t, reqArgs.Req() != nil) test.Assert(t, resArgs.Res() != nil) @@ -278,6 +283,8 @@ func TestStreamingBasic(t *testing.T) { case streamx.StreamingBidirectional: test.Assert(t, reqArgs.Req() == nil) test.Assert(t, resArgs.Res() == nil) + default: + t.Fatal("cannot get stream mode") } } increaseIfNoError(&clientStreamCount, err) @@ -610,7 +617,7 @@ func TestStreamingException(t *testing.T) { client.WithHostPorts(addr), streamxclient.WithProvider(tc.ClientProvider), - streamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + istreamxclient.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { ri := rpcinfo.GetRPCInfo(ctx) test.Assert(t, ri.To().Address() != nil) @@ -661,7 +668,7 @@ func TestStreamingGoroutineLeak(t *testing.T) { addr, svr, err := NewTestServer( new(testService), streamxserver.WithProvider(tc.ServerProvider), - streamxserver.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { + istreamxserver.WithStreamMiddleware(func(next streamx.StreamEndpoint) streamx.StreamEndpoint { return func(ctx context.Context, streamArgs streamx.StreamArgs, reqArgs streamx.StreamReqArgs, resArgs streamx.StreamResArgs) (err error) { atomic.AddInt32(&streamStarted, 1) return next(ctx, streamArgs, reqArgs, resArgs) diff --git a/server/streamxserver/option.go b/server/streamxserver/option.go index 109f0c096a..2e765e1dd6 100644 --- a/server/streamxserver/option.go +++ b/server/streamxserver/option.go @@ -29,12 +29,6 @@ func WithProvider(provider streamx.ServerProvider) server.Option { }} } -func WithStreamMiddleware(mw streamx.StreamMiddleware) server.Option { - return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { - o.StreamX.StreamMiddlewares = append(o.StreamX.StreamMiddlewares, mw) - }} -} - func WithStreamRecvMiddleware(mw streamx.StreamRecvMiddleware) server.Option { return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { o.StreamX.StreamRecvMiddlewares = append(o.StreamX.StreamRecvMiddlewares, mw) From 19cf804fe54d1c71b27fabf783ec97358b3d3583 Mon Sep 17 00:00:00 2001 From: Joway Date: Fri, 15 Nov 2024 14:16:27 +0800 Subject: [PATCH 25/34] feat: using original trace ctx (#1613) --- client/client_streamx.go | 21 +++++++++++++++++++-- client/stream.go | 11 +++++------ pkg/streamx/stream_middleware_internal.go | 18 ++++++++++++++---- server/stream.go | 4 ++-- 4 files changed, 40 insertions(+), 14 deletions(-) diff --git a/client/client_streamx.go b/client/client_streamx.go index 16093c1980..cdd8d27de5 100644 --- a/client/client_streamx.go +++ b/client/client_streamx.go @@ -54,8 +54,25 @@ func (kc *kClient) NewStream(ctx context.Context, method string, streamArgs stre if msargs := streamx.AsMutableStreamArgs(streamArgs); msargs != nil { msargs.SetStreamMiddleware(kc.sxStreamMW) - msargs.SetStreamRecvMiddleware(kc.sxStreamRecvMW) - msargs.SetStreamSendMiddleware(kc.sxStreamSendMW) + + eventHandler := kc.opt.TracerCtl.GetStreamEventHandler() + if eventHandler == nil { + msargs.SetStreamRecvMiddleware(kc.sxStreamRecvMW) + msargs.SetStreamSendMiddleware(kc.sxStreamSendMW) + } else { + traceRecvMW := streamx.NewStreamRecvStatMiddleware(ctx, eventHandler) + traceSendMW := streamx.NewStreamSendStatMiddleware(ctx, eventHandler) + if kc.sxStreamRecvMW == nil { + msargs.SetStreamRecvMiddleware(traceRecvMW) + } else { + msargs.SetStreamRecvMiddleware(streamx.StreamRecvMiddlewareChain(traceRecvMW, kc.sxStreamRecvMW)) + } + if kc.sxStreamSendMW == nil { + msargs.SetStreamSendMiddleware(traceSendMW) + } else { + msargs.SetStreamSendMiddleware(streamx.StreamSendMiddlewareChain(traceSendMW, kc.sxStreamSendMW)) + } + } } // with streamx mode, req is nil and resp is streamArgs // it's an ugly trick but if we don't want to refactor too much, diff --git a/client/stream.go b/client/stream.go index 5e2219bcc0..cf656503ae 100644 --- a/client/stream.go +++ b/client/stream.go @@ -82,13 +82,12 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) { // streamx version streaming mw kc.sxStreamMW = streamx.StreamMiddlewareChain(kc.opt.StreamX.StreamMWs...) - eventHandler := kc.opt.TracerCtl.GetStreamEventHandler() - if eventHandler != nil { - kc.opt.StreamX.StreamRecvMWs = append(kc.opt.StreamX.StreamRecvMWs, streamx.NewStreamRecvStatMiddleware(eventHandler)) - kc.opt.StreamX.StreamSendMWs = append(kc.opt.StreamX.StreamSendMWs, streamx.NewStreamSendStatMiddleware(eventHandler)) + if len(kc.opt.StreamX.StreamRecvMWs) > 0 { + kc.sxStreamRecvMW = streamx.StreamRecvMiddlewareChain(kc.opt.StreamX.StreamRecvMWs...) + } + if len(kc.opt.StreamX.StreamSendMWs) > 0 { + kc.sxStreamSendMW = streamx.StreamSendMiddlewareChain(kc.opt.StreamX.StreamSendMWs...) } - kc.sxStreamRecvMW = streamx.StreamRecvMiddlewareChain(kc.opt.StreamX.StreamRecvMWs...) - kc.sxStreamSendMW = streamx.StreamSendMiddlewareChain(kc.opt.StreamX.StreamSendMWs...) return func(ctx context.Context, req, resp interface{}) (err error) { // req and resp as &streaming.Stream diff --git a/pkg/streamx/stream_middleware_internal.go b/pkg/streamx/stream_middleware_internal.go index d06d7d491d..a6d5d2daf1 100644 --- a/pkg/streamx/stream_middleware_internal.go +++ b/pkg/streamx/stream_middleware_internal.go @@ -22,21 +22,31 @@ import ( "github.com/cloudwego/kitex/pkg/stats" ) -func NewStreamRecvStatMiddleware(ehandler EventHandler) StreamRecvMiddleware { +func NewStreamRecvStatMiddleware(traceCtx context.Context, ehandler EventHandler) StreamRecvMiddleware { return func(next StreamRecvEndpoint) StreamRecvEndpoint { return func(ctx context.Context, stream Stream, res any) (err error) { err = next(ctx, stream, res) - ehandler(ctx, stats.StreamRecv, err) + // if traceCtx is nil, using the current ctx + // otherwise, we should use the original trace ctx instead + if traceCtx == nil { + traceCtx = ctx + } + ehandler(traceCtx, stats.StreamRecv, err) return err } } } -func NewStreamSendStatMiddleware(ehandler EventHandler) StreamSendMiddleware { +func NewStreamSendStatMiddleware(traceCtx context.Context, ehandler EventHandler) StreamSendMiddleware { return func(next StreamSendEndpoint) StreamSendEndpoint { return func(ctx context.Context, stream Stream, res any) (err error) { err = next(ctx, stream, res) - ehandler(ctx, stats.StreamSend, err) + // if traceCtx is nil, using the current ctx + // otherwise, we should use the original trace ctx instead + if traceCtx == nil { + traceCtx = ctx + } + ehandler(traceCtx, stats.StreamSend, err) return err } } diff --git a/server/stream.go b/server/stream.go index 59b844bb1a..e742be4b28 100644 --- a/server/stream.go +++ b/server/stream.go @@ -34,10 +34,10 @@ func (s *server) initStreamMiddlewares(ctx context.Context) { ehandler := s.opt.TracerCtl.GetStreamEventHandler() if ehandler != nil { s.opt.StreamX.StreamRecvMiddlewares = append( - s.opt.StreamX.StreamRecvMiddlewares, streamx.NewStreamRecvStatMiddleware(ehandler), + s.opt.StreamX.StreamRecvMiddlewares, streamx.NewStreamRecvStatMiddleware(nil, ehandler), ) s.opt.StreamX.StreamSendMiddlewares = append( - s.opt.StreamX.StreamSendMiddlewares, streamx.NewStreamSendStatMiddleware(ehandler), + s.opt.StreamX.StreamSendMiddlewares, streamx.NewStreamSendStatMiddleware(nil, ehandler), ) } } From 9401a78a12f914b5607f87fb82ac980539af9088 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Fri, 15 Nov 2024 17:23:45 +0800 Subject: [PATCH 26/34] chore: fix service name and check str header and int header --- pkg/streamx/provider/ttstream/client_provier.go | 5 ++++- pkg/streamx/streamx_gen_service_test.go | 2 +- pkg/streamx/streamx_user_test.go | 4 ---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go index f7556c7068..73ccac84b4 100644 --- a/pkg/streamx/provider/ttstream/client_provier.go +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -71,8 +71,11 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (stre if err != nil { return nil, err } - } else { + } + if intHeader == nil { intHeader = IntHeader{} + } + if strHeader == nil { strHeader = map[string]string{} } strHeader[ttheader.HeaderIDLServiceName] = c.sinfo.ServiceName diff --git a/pkg/streamx/streamx_gen_service_test.go b/pkg/streamx/streamx_gen_service_test.go index bf2828fd6d..8bbaf8f3ca 100644 --- a/pkg/streamx/streamx_gen_service_test.go +++ b/pkg/streamx/streamx_gen_service_test.go @@ -32,7 +32,7 @@ import ( // --- Define Service Method handler --- var testServiceInfo = &serviceinfo.ServiceInfo{ - ServiceName: "kitex.echo.service", + ServiceName: "TestService", PayloadCodec: serviceinfo.Thrift, HandlerType: (*TestService)(nil), Methods: map[string]serviceinfo.MethodInfo{ diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index 0f87a9e6ca..6b223f5856 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -598,10 +598,6 @@ func TestStreamingBasic(t *testing.T) { func TestStreamingException(t *testing.T) { for _, tc := range providerTestCases { t.Run(tc.Name, func(t *testing.T) { - addr := test.GetLocalAddress() - ln, _ := netpoll.CreateListener("tcp", addr) - defer ln.Close() - // create server addr, svr, err := NewTestServer( new(testService), From 694f7e97279b78b15a7ade06f879fdf155fc566b Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Fri, 15 Nov 2024 18:39:28 +0800 Subject: [PATCH 27/34] refactor: refine streamx generation code (#1610) --- pkg/streamx/streamx_gen_service_test.go | 7 +- tool/cmd/kitex/args/args.go | 2 +- tool/internal_pkg/generator/generator.go | 171 +++++------------- tool/internal_pkg/generator/type.go | 3 + .../pluginmode/thriftgo/convertor.go | 3 + tool/internal_pkg/tpl/client.go | 50 ++++- tool/internal_pkg/tpl/handler.method.go | 22 ++- tool/internal_pkg/tpl/server.go | 24 ++- tool/internal_pkg/tpl/service.go | 107 ++++++++++- tool/internal_pkg/tpl/streamx/client.go | 107 ----------- .../tpl/streamx/handler.method.go | 37 ---- tool/internal_pkg/tpl/streamx/server.go | 61 ------- tool/internal_pkg/tpl/streamx/service.go | 68 ------- 13 files changed, 244 insertions(+), 418 deletions(-) delete mode 100644 tool/internal_pkg/tpl/streamx/client.go delete mode 100644 tool/internal_pkg/tpl/streamx/handler.method.go delete mode 100644 tool/internal_pkg/tpl/streamx/server.go delete mode 100644 tool/internal_pkg/tpl/streamx/service.go diff --git a/pkg/streamx/streamx_gen_service_test.go b/pkg/streamx/streamx_gen_service_test.go index 8bbaf8f3ca..f44ba68e66 100644 --- a/pkg/streamx/streamx_gen_service_test.go +++ b/pkg/streamx/streamx_gen_service_test.go @@ -20,6 +20,7 @@ import ( "context" "github.com/cloudwego/kitex/client" + "github.com/cloudwego/kitex/client/callopt" "github.com/cloudwego/kitex/client/streamxclient" "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -84,7 +85,7 @@ var testServiceInfo = &serviceinfo.ServiceInfo{ ), "ServerStream": serviceinfo.NewMethodInfo( func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeServerStreamHandler( + return streamxserver.InvokeServerStreamHandler[Request, Response]( ctx, reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs), func(ctx context.Context, req *Request, stream streamx.ServerStreamingServer[Response]) error { return handler.(TestService).ServerStream(ctx, req, stream) @@ -215,7 +216,7 @@ type TestService interface { // --- Define Client Implementation Interface --- type TestServiceClient interface { - PingPong(ctx context.Context, req *Request) (r *Response, err error) + PingPong(ctx context.Context, req *Request, callOptions ...callopt.Option) (r *Response, err error) Unary(ctx context.Context, req *Request, callOptions ...streamxcallopt.CallOption) (r *Response, err error) ClientStream(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( @@ -240,7 +241,7 @@ type kClient struct { caller client.Client } -func (c *kClient) PingPong(ctx context.Context, req *Request) (r *Response, err error) { +func (c *kClient) PingPong(ctx context.Context, req *Request, callOptions ...callopt.Option) (r *Response, err error) { var _args ServerPingPongArgs _args.Req = req var _result ServerPingPongResult diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index 0198a29420..620fbc52a6 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -253,7 +253,7 @@ func (a *Arguments) checkStreamX() error { // set TTHeader Streaming by default a.Protocol = transport.TTHeader.String() } - // todo: process pb and gRPC + // todo(DMwangnima): process pb and gRPC return nil } diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 29d444d271..ec46085467 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -26,7 +26,6 @@ import ( "github.com/cloudwego/kitex/tool/internal_pkg/log" "github.com/cloudwego/kitex/tool/internal_pkg/tpl" - "github.com/cloudwego/kitex/tool/internal_pkg/tpl/streamx" "github.com/cloudwego/kitex/tool/internal_pkg/util" "github.com/cloudwego/kitex/transport" ) @@ -303,9 +302,6 @@ func (c *Config) ApplyExtension() error { } func (c *Config) IsUsingMultipleServicesTpl() bool { - if c.StreamX { - return true - } for _, part := range c.BuiltinTpl { if part == MultipleServicesTpl { return true @@ -436,19 +432,10 @@ func (g *generator) generateHandler(pkg *PackageInfo, svc *ServiceInfo, handlerF return f, nil } - var task Task - if g.StreamX && svc.HasStreaming { - task = Task{ - Name: HandlerFileName, - Path: handlerFilePath, - Text: tpl.HandlerTpl + "\n" + streamx.HandlerMethodsTpl, - } - } else { - task = Task{ - Name: HandlerFileName, - Path: handlerFilePath, - Text: tpl.HandlerTpl + "\n" + tpl.HandlerMethodsTpl, - } + task := Task{ + Name: HandlerFileName, + Path: handlerFilePath, + Text: tpl.HandlerTpl + "\n" + tpl.HandlerMethodsTpl, } g.setImports(task.Name, pkg) handle := func(task *Task, pkg *PackageInfo) (*File, error) { @@ -488,11 +475,6 @@ func (g *generator) GenerateService(pkg *PackageInfo) ([]*File, error) { Path: util.JoinPath(output, svcPkg+".go"), Text: tpl.ServiceTpl, } - if g.StreamX && pkg.ServiceInfo.HasStreaming { - cliTask.Text = streamx.ClientTpl - svrTask.Text = streamx.ServerTpl - svcTask.Text = streamx.ServiceTpl - } tasks := []*Task{cliTask, svrTask, svcTask} // do not generate invoker.go in service package by default @@ -562,31 +544,28 @@ func (g *generator) setImports(name string, pkg *PackageInfo) { pkg.Imports = make(map[string]map[string]bool) switch name { case ClientFileName: - if g.StreamX && pkg.HasStreaming { - g.setStreamXClientImports(pkg) - } else { - pkg.AddImports("client") - if pkg.HasStreaming { - pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming") - pkg.AddImport("transport", "github.com/cloudwego/kitex/transport") - } - if len(pkg.AllMethods()) > 0 { - if needCallOpt(pkg) { - pkg.AddImports("callopt") - } - pkg.AddImports("context") + pkg.AddImports("client") + if !g.StreamX && pkg.HasStreaming { + pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming") + pkg.AddImport("transport", "github.com/cloudwego/kitex/transport") + } + if len(pkg.AllMethods()) > 0 { + if needCallOpt(pkg) { + pkg.AddImports("callopt") } + pkg.AddImports("context") } fallthrough case HandlerFileName: - if g.StreamX && pkg.HasStreaming { - g.setStreamXHandlerImports(pkg) - return - } for _, m := range pkg.ServiceInfo.AllMethods() { - if !m.ServerStreaming && !m.ClientStreaming { + // for StreamX interface, every method in handler has ctx argument + // for old interface, streaming method in handler does not have ctx argument + if g.StreamX || (!m.ServerStreaming && !m.ClientStreaming) { pkg.AddImports("context") } + if g.StreamX && m.Streaming.IsStreaming { + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") + } for _, a := range m.Args { for _, dep := range a.Deps { pkg.AddImport(dep.PkgRefName, dep.ImportPath) @@ -599,19 +578,20 @@ func (g *generator) setImports(name string, pkg *PackageInfo) { } } case ServerFileName, InvokerFileName: - if g.StreamX && pkg.HasStreaming { - g.setStreamXServerImports(pkg) - return + // for StreamX, if there is streaming method, generate Server Interface in server.go + if g.StreamX { + for _, method := range pkg.AllMethods() { + if method.Streaming.IsStreaming { + pkg.AddImports("context") + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") + } + } } if len(pkg.CombineServices) == 0 { pkg.AddImport(pkg.ServiceInfo.PkgRefName, pkg.ServiceInfo.ImportPath) } pkg.AddImports("server") case ServiceFileName: - if g.StreamX && pkg.HasStreaming { - g.setStreamXServiceImports(pkg) - return - } pkg.AddImports("errors") pkg.AddImports("client") pkg.AddImport("kitex", "github.com/cloudwego/kitex/pkg/serviceinfo") @@ -620,9 +600,6 @@ func (g *generator) setImports(name string, pkg *PackageInfo) { pkg.AddImports("context") } for _, m := range pkg.ServiceInfo.AllMethods() { - if m.ClientStreaming || m.ServerStreaming { - pkg.AddImports("fmt") - } if m.GenArgResultStruct { pkg.AddImports("proto") } else { @@ -634,9 +611,22 @@ func (g *generator) setImports(name string, pkg *PackageInfo) { pkg.AddImport(dep.PkgRefName, dep.ImportPath) } } - if m.Streaming.IsStreaming || pkg.Codec == "protobuf" { - // protobuf handler support both PingPong and Unary (streaming) requests - pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming") + // streaming imports + if !g.StreamX { + if m.Streaming.IsStreaming || pkg.Codec == "protobuf" { + // protobuf handler support both PingPong and Unary (streaming) requests + pkg.AddImport("streaming", "github.com/cloudwego/kitex/pkg/streaming") + } + if m.ClientStreaming || m.ServerStreaming { + pkg.AddImports("fmt") + } + } else { + if m.Streaming.IsStreaming { + pkg.AddImports("github.com/cloudwego/kitex/client/streamxclient") + pkg.AddImports("github.com/cloudwego/kitex/client/streamxclient/streamxcallopt") + pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") + pkg.AddImports("github.com/cloudwego/kitex/server/streamxserver") + } } if !m.Void && m.Resp != nil { for _, dep := range m.Resp.Deps { @@ -685,78 +675,3 @@ func needCallOpt(pkg *PackageInfo) bool { } return needCallOpt } - -func (g *generator) setStreamXClientImports(pkg *PackageInfo) { - pkg.AddImports("client") - pkg.AddImports("github.com/cloudwego/kitex/client/streamxclient") - if len(pkg.AllMethods()) > 0 { - pkg.AddImports("context") - pkg.AddImports("github.com/cloudwego/kitex/client/streamxclient/streamxcallopt") - pkg.AddImports("github.com/cloudwego/kitex/pkg/serviceinfo") - } - pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") - if g.IDLType == "thrift" { - pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef) - } -} - -func (g *generator) setStreamXServerImports(pkg *PackageInfo) { - pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") - pkg.AddImports("server") - pkg.AddImports("github.com/cloudwego/kitex/server/streamxserver") - if g.IDLType == "thrift" { - pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef) - } - for _, m := range pkg.AllMethods() { - pkg.AddImports("context") - for _, a := range m.Args { - for _, dep := range a.Deps { - pkg.AddImport(dep.PkgRefName, dep.ImportPath) - } - } - if !m.Void && m.Resp != nil { - for _, dep := range m.Resp.Deps { - pkg.AddImport(dep.PkgRefName, dep.ImportPath) - } - } - } -} - -func (g *generator) setStreamXServiceImports(pkg *PackageInfo) { - pkg.AddImports("github.com/cloudwego/kitex/pkg/serviceinfo") - for _, m := range pkg.AllMethods() { - pkg.AddImports("context") - pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") - pkg.AddImports("github.com/cloudwego/kitex/server/streamxserver") - for _, a := range m.Args { - for _, dep := range a.Deps { - pkg.AddImport(dep.PkgRefName, dep.ImportPath) - } - } - if !m.Void && m.Resp != nil { - for _, dep := range m.Resp.Deps { - pkg.AddImport(dep.PkgRefName, dep.ImportPath) - } - } - } -} - -func (g *generator) setStreamXHandlerImports(pkg *PackageInfo) { - for _, m := range pkg.ServiceInfo.AllMethods() { - pkg.AddImports("context") - pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx") - if g.IDLType == "thrift" { - pkg.AddImports("github.com/cloudwego/kitex/pkg/streamx/provider/" + streamxTTHeaderRef) - } - for _, a := range m.Args { - for _, dep := range a.Deps { - pkg.AddImport(dep.PkgRefName, dep.ImportPath) - } - } - if !m.Void && m.Resp != nil { - for _, dep := range m.Resp.Deps { - pkg.AddImport(dep.PkgRefName, dep.ImportPath) - } - } - } -} diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index efdae50ef9..02178ded45 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -159,6 +159,8 @@ type ServiceInfo struct { RefName string // identify whether this service would generate a corresponding handler. GenerateHandler bool + // whether to generate StreamX interface code + StreamX bool } // AllMethods returns all methods that the service have. @@ -212,6 +214,7 @@ type MethodInfo struct { ClientStreaming bool ServerStreaming bool Streaming *streaming.Streaming + StreamX bool } // Parameter . diff --git a/tool/internal_pkg/pluginmode/thriftgo/convertor.go b/tool/internal_pkg/pluginmode/thriftgo/convertor.go index ff5498afb3..13f5d8f15d 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/convertor.go +++ b/tool/internal_pkg/pluginmode/thriftgo/convertor.go @@ -383,6 +383,7 @@ func (c *converter) convertTypes(req *plugin.Request) error { ServiceFilePath: ast.Filename, HasStreaming: hasStreaming, GenerateHandler: true, + StreamX: c.Config.StreamX, } if c.IsHessian2() { @@ -421,6 +422,7 @@ func (c *converter) makeService(pkg generator.PkgInfo, svc *golang.Service) (*ge PkgInfo: pkg, ServiceName: svc.GoName().String(), RawServiceName: svc.Name, + StreamX: c.Config.StreamX, } si.ServiceTypeName = func() string { return si.PkgRefName + "." + si.ServiceName } @@ -464,6 +466,7 @@ func (c *converter) makeMethod(si *generator.ServiceInfo, f *golang.Function) (* ClientStreaming: st.ClientStreaming, ServerStreaming: st.ServerStreaming, ArgsLength: len(f.Arguments()), + StreamX: si.StreamX, } if st.IsStreaming { si.HasStreaming = true diff --git a/tool/internal_pkg/tpl/client.go b/tool/internal_pkg/tpl/client.go index 666e966898..6044775ed9 100644 --- a/tool/internal_pkg/tpl/client.go +++ b/tool/internal_pkg/tpl/client.go @@ -29,8 +29,13 @@ import ( {{- end}} {{- end}} {{- if .HasStreaming}} + {{- if not .StreamX}} "github.com/cloudwego/kitex/client/streamclient" "github.com/cloudwego/kitex/client/callopt/streamcall" + {{- else}} + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/streamx" + {{- end}}{{- /* if not .StreamX end */}} {{- end}} ) // Client is designed to provide IDL-compatible methods with call-option parameter for kitex framework. @@ -43,9 +48,23 @@ type Client interface { {{.Name}}(ctx context.Context {{range .Args}}, {{.RawName}} {{.Type}}{{end}}, callOptions ...callopt.Option ) ({{if not .Void}}r {{.Resp.Type}}, {{end}}err error) {{- end}} {{- end}} +{{- /* Streamx interface for streaming method */}} +{{- if and .StreamX (eq $.Codec "thrift") .Streaming.IsStreaming}} +{{- $streamingUnary := (eq .Streaming.Mode "unary")}} +{{- $clientSide := (eq .Streaming.Mode "client")}} +{{- $serverSide := (eq .Streaming.Mode "server")}} +{{- $bidiSide := (eq .Streaming.Mode "bidirectional")}} +{{- $arg := index .Args 0}} + {{.Name}}{{- if $streamingUnary}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (r {{.Resp.Type}}, err error) + {{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], error) + {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ServerStreamingClient[{{NotPtr .Resp.Type}}], error) + {{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.BidiStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], error) + {{- end}} +{{- end}}{{- /* if and .StreamX (eq $.Codec "thrift") .Streaming.IsStreaming end */}} {{- end}} } +{{- if not .StreamX}} {{- if .HasStreaming}} // StreamClient is designed to provide Interface for Streaming APIs. type StreamClient interface { @@ -75,6 +94,7 @@ type {{.ServiceName}}_{{.RawName}}Client interface { } {{- end}} {{end}} +{{- end}}{{- /* if not .StreamX end */}} // NewClient creates a client for the service defined in IDL. func NewClient(destService string, opts ...client.Option) (Client, error) { @@ -87,9 +107,13 @@ func NewClient(destService string, opts ...client.Option) (Client, error) { {{end}} options = append(options, opts...) - kc, err := client.NewClient( - {{- if eq $.Codec "protobuf"}}serviceInfo(){{else}}serviceInfoForClient(){{end -}} - , options...) + {{- if and .StreamX .HasStreaming}} + kc, err := client.NewClient(serviceInfo(), options...) + {{- else}} + kc, err := client.NewClient( + {{- if eq $.Codec "protobuf"}}serviceInfo(){{else}}serviceInfoForClient(){{end -}} + , options...) + {{- end}}{{/* if .StreamX .HasStreaming end */}} if err != nil { return nil, err } @@ -131,8 +155,27 @@ func (p *k{{$.ServiceName}}Client) {{.Name}}(ctx context.Context {{range .Args}} } {{- end}} {{- end}} +{{- /* Streamx interface for streaming method */}} +{{- if and .StreamX (eq $.Codec "thrift") .Streaming.IsStreaming}} +{{- $streamingUnary := (eq .Streaming.Mode "unary")}} +{{- $clientSide := (eq .Streaming.Mode "client")}} +{{- $serverSide := (eq .Streaming.Mode "server")}} +{{- $bidiSide := (eq .Streaming.Mode "bidirectional")}} +{{- $arg := index .Args 0}} +func (p *k{{$.ServiceName}}Client) {{.Name}}{{- if $streamingUnary}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (r {{.Resp.Type}}, err error) { + return p.kClient.{{.Name}}(ctx, req, callOptions...) + {{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ClientStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], error) { + return p.kClient.{{.Name}}(ctx, callOptions...) + {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.ServerStreamingClient[{{NotPtr .Resp.Type}}], error) { + return p.kClient.{{.Name}}(ctx, req, callOptions...) + {{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (context.Context, streamx.BidiStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], error) { + return p.kClient.{{.Name}}(ctx, callOptions...) + {{- end}} +} +{{- end}}{{- /* if and .StreamX (eq $.Codec "thrift") .Streaming.IsStreaming end */}} {{end}} +{{- if not .StreamX}} {{- if .HasStreaming}} // NewStreamClient creates a stream client for the service's streaming APIs defined in IDL. func NewStreamClient(destService string, opts ...streamclient.Option) (StreamClient, error) { @@ -181,5 +224,6 @@ func (p *k{{$.ServiceName}}StreamClient) {{.Name}}(ctx context.Context {{range . {{- end}} {{end}} {{- end}} +{{- end}}{{- /* if not .StreamX end */}} {{template "@client.go-EOF" .}} ` diff --git a/tool/internal_pkg/tpl/handler.method.go b/tool/internal_pkg/tpl/handler.method.go index efbc3653f2..88a6904c63 100644 --- a/tool/internal_pkg/tpl/handler.method.go +++ b/tool/internal_pkg/tpl/handler.method.go @@ -17,6 +17,21 @@ package tpl // HandlerMethodsTpl is the template for generating methods in handler.go. var HandlerMethodsTpl string = `{{define "HandlerMethod"}} {{range .AllMethods}} +{{- if and .StreamX .Streaming.IsStreaming}} +{{- $streamingUnary := (eq .Streaming.Mode "unary")}} +{{- $clientSide := (eq .Streaming.Mode "client")}} +{{- $serverSide := (eq .Streaming.Mode "server")}} +{{- $bidiSide := (eq .Streaming.Mode "bidirectional")}} +{{- $arg := index .Args 0}} +func (s *{{.ServiceName}}Impl) {{.Name}}{{- if $streamingUnary}}(ctx context.Context, req {{$arg.Type}}) (resp {{.Resp.Type}}, err error) { + {{- else if $clientSide}}(ctx context.Context, stream streamx.ClientStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) (resp {{.Resp.Type}}, err error) { + {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{NotPtr .Resp.Type}}]) (err error) { + {{- else if $bidiSide}}(ctx context.Context, stream streamx.BidiStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) (err error) { + {{- end}} + // TODO: Your code here... + return +} +{{- else}} {{- if or .ClientStreaming .ServerStreaming}} func (s *{{$.ServiceName}}Impl) {{.Name}}({{if not .ClientStreaming}}{{range .Args}}{{LowerFirst .Name}} {{.Type}}, {{end}}{{end}}stream {{.PkgRefName}}.{{.ServiceName}}_{{.RawName}}Server) (err error) { println("{{.Name}} called") @@ -40,8 +55,9 @@ func (s *{{$.ServiceName}}Impl) {{.Name}}(ctx context.Context {{range .Args}}, { // TODO: Your code here... return } -{{end}} -{{end}} -{{end}} +{{end}}{{/* if .Void end */}} +{{end}}{{/* if or .ClientStreaming .ServerStreaming end */}} +{{end}}{{/* if and .StreamX .Streaming.IsStreaming end */}} +{{end}}{{/* range .AllMethods end */}} {{end}}{{/* define "HandlerMethod" */}} ` diff --git a/tool/internal_pkg/tpl/server.go b/tool/internal_pkg/tpl/server.go index d292a97932..78d7cfdc9f 100644 --- a/tool/internal_pkg/tpl/server.go +++ b/tool/internal_pkg/tpl/server.go @@ -30,8 +30,28 @@ import ( {{- end}} ) +{{- $serverInterfaceName := (call .ServiceTypeName) }} +{{- if and .StreamX .HasStreaming}} +{{- $serverInterfaceName = .ServiceName }} +type {{.ServiceName}} interface { +{{- range .AllMethods}} +{{- $streamingUnary := (eq .Streaming.Mode "unary")}} +{{- $clientSide := (eq .Streaming.Mode "client")}} +{{- $serverSide := (eq .Streaming.Mode "server")}} +{{- $bidiSide := (eq .Streaming.Mode "bidirectional")}} +{{- $arg := index .Args 0}} + {{.Name}}{{- if $streamingUnary}}(ctx context.Context, req {{$arg.Type}}) ({{.Resp.Type}}, error) + {{- else if $clientSide}}(ctx context.Context, stream streamx.ClientStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) ({{.Resp.Type}}, error) + {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{NotPtr .Resp.Type}}]) error + {{- else if $bidiSide}}(ctx context.Context, stream streamx.BidiStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) error + {{- else }}(ctx context.Context {{range .Args}}, {{.RawName}} {{.Type}}{{end}}) ({{if not .Void}}r {{.Resp.Type}}, {{end}}err error) + {{- end}} +{{- end}} +} +{{- end}} {{- /* if and .StreamX .HasStreaming end */}} + // NewServer creates a server.Server with the given handler and options. -func NewServer(handler {{call .ServiceTypeName}}, opts ...server.Option) server.Server { +func NewServer(handler {{$serverInterfaceName}}, opts ...server.Option) server.Server { var options []server.Option {{template "@server.go-NewServer-option" .}} options = append(options, opts...) @@ -50,7 +70,7 @@ func NewServer(handler {{call .ServiceTypeName}}, opts ...server.Option) server. } {{template "@server.go-EOF" .}} -func RegisterService(svr server.Server, handler {{call .ServiceTypeName}}, opts ...server.RegisterOption) error { +func RegisterService(svr server.Server, handler {{$serverInterfaceName}}, opts ...server.RegisterOption) error { return svr.RegisterService(serviceInfo(), handler, opts...) } ` diff --git a/tool/internal_pkg/tpl/service.go b/tool/internal_pkg/tpl/service.go index d5a72428dc..745eb802af 100644 --- a/tool/internal_pkg/tpl/service.go +++ b/tool/internal_pkg/tpl/service.go @@ -58,21 +58,43 @@ var serviceMethods = map[string]kitex.MethodInfo{ {{- else -}} kitex.StreamingNone {{- end -}} {{- end}}), + {{- if and .StreamX .Streaming.IsStreaming}} + kitex.WithMethodExtra("streamx", "true"), + {{- end}} ), {{- end}} } +{{- if and .StreamX .HasStreaming}} +var {{LowerFirst .ServiceName}}ServiceInfo = NewServiceInfo() +{{- else}} var ( {{LowerFirst .ServiceName}}ServiceInfo = NewServiceInfo() {{LowerFirst .ServiceName}}ServiceInfoForClient = NewServiceInfoForClient() {{LowerFirst .ServiceName}}ServiceInfoForStreamClient = NewServiceInfoForStreamClient() ) +{{- end}} {{- /* if and .StreamX .HasStreaming end */}} // for server func serviceInfo() *kitex.ServiceInfo { return {{LowerFirst .ServiceName}}ServiceInfo } +{{- if and .StreamX .HasStreaming}} +// NewServiceInfo creates a new ServiceInfo containing all methods +{{- /* It's for the Server (providing both streaming/non-streaming APIs), or for the grpc client */}} +func NewServiceInfo() *kitex.ServiceInfo { + return newServiceInfo() +} + +func newServiceInfo() *kitex.ServiceInfo { + return &kitex.ServiceInfo{ + ServiceName: "{{.RawServiceName}}", + PayloadCodec: kitex.Thrift, + Methods: serviceMethods, + } +} +{{- else}} {{- /* old streaming interface */}} // for stream client func serviceInfoForStreamClient() *kitex.ServiceInfo { return {{LowerFirst .ServiceName}}ServiceInfoForStreamClient @@ -140,15 +162,24 @@ func newServiceInfo(hasStreaming bool, keepStreamingMethods bool, keepNonStreami } return svcInfo } +{{- end}}{{- /* if and .StreamX .HasStreaming end */}} {{range .AllMethods}} {{- $isStreaming := or .ClientStreaming .ServerStreaming}} +{{- $streamingUnary := (eq .Streaming.Mode "unary")}} {{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} {{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} {{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} {{- $bidiSide := and .ClientStreaming .ServerStreaming}} -{{- $arg := "" }} -{{- if or (eq $.Codec "protobuf") ($isStreaming) }} +{{- $arg := ""}} +{{- $handlerFunc := ""}} +{{- $mode := ""}} + {{- if $streamingUnary -}} {{- $mode = "serviceinfo.StreamingUnary" }} {{- $handlerFunc = "InvokeUnaryHandler" }} + {{- else if $serverSide -}} {{- $mode = "serviceinfo.StreamingServer" }} {{- $handlerFunc = "InvokeServerStreamHandler" }} + {{- else if $clientSide -}} {{- $mode = "serviceinfo.StreamingClient" }} {{- $handlerFunc = "InvokeClientStreamHandler" }} + {{- else if $bidiSide -}} {{- $mode = "serviceinfo.StreamingBidirectional" }} {{- $handlerFunc = "InvokeBidiStreamHandler" }} + {{- end}} +{{- if or (eq $.Codec "protobuf") ($isStreaming) (.Streaming.IsStreaming) }} {{- $arg = index .Args 0}}{{/* streaming api only supports exactly one argument */}} {{- end}} @@ -194,8 +225,23 @@ func {{LowerFirst .Name}}Handler(ctx context.Context, handler interface{}, arg, return handler.({{.PkgRefName}}.{{.ServiceName}}).{{.Name}}({{if $serverSide}}req, {{end}}stream) {{- end}} {{/* $unary end */}} {{- else}} {{/* thrift logic */}} + {{- if and .StreamX .Streaming.IsStreaming}} + return streamxserver.{{$handlerFunc}}[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + ctx, arg.(streamx.StreamReqArgs), result.(streamx.StreamResArgs), + {{- if $streamingUnary }}func(ctx context.Context, req {{$arg.Type}}) ({{.Resp.Type}}, error) { + return handler.({{.ServiceName}}).{{.Name}}(ctx, req) + {{- else if $clientSide }}func(ctx context.Context, stream streamx.ClientStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) ({{.Resp.Type}}, error) { + return handler.({{.ServiceName}}).{{.Name}}(ctx, stream) + {{- else if $serverSide }}func(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{NotPtr .Resp.Type}}]) error { + return handler.({{.ServiceName}}).{{.Name}}(ctx, req, stream) + {{- else if $bidiSide }}func(ctx context.Context, stream streamx.BidiStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) error { + return handler.({{.ServiceName}}).{{.Name}}(ctx, stream) + {{- end}} + }, + ) + {{- else}}{{/* old interface generation code */}} {{- if $unary}} {{/* unary logic */}} - {{- if eq .Streaming.Mode "unary"}} + {{- if $streamingUnary}} if streaming.GetStream(ctx) == nil { return errors.New("{{.ServiceName}}.{{.Name}} is a thrift streaming unary method, please call with Kitex StreamClient or remove the annotation streaming.mode") } @@ -244,10 +290,12 @@ func {{LowerFirst .Name}}Handler(ctx context.Context, handler interface{}, arg, } return handler.({{.PkgRefName}}.{{.ServiceName}}).{{.Name}}(req, stream) {{- end}} {{/* $serverSide end*/}} - {{- end}} {{/* thrift end */}} + {{- end}} + {{- end}} {{- end}} {{/* protobuf end */}} } +{{- if not .StreamX }} {{- /* define streaming struct */}} {{- if $isStreaming}} type {{LowerFirst .ServiceName}}{{.RawName}}Client struct { @@ -311,6 +359,7 @@ func (x *{{LowerFirst .ServiceName}}{{.RawName}}Server) Recv() ({{$arg.Type}}, e {{- end}} {{- /* define streaming struct end */}} +{{- end}}{{- /* if not .StreamX end*/}} func new{{.ArgStructName}}() interface{} { return {{if not .GenArgResultStruct}}{{.PkgRefName}}.New{{.ArgStructName}}(){{else}}&{{.ArgStructName}}{}{{end}} } @@ -449,17 +498,64 @@ func (p *{{.ResStructName}}) GetResult() interface{} { type kClient struct { c client.Client + {{- if and .StreamX .HasStreaming}} + streamer client.StreamX + {{- end}} } func newServiceClient(c client.Client) *kClient { return &kClient{ c: c, + {{- if and .StreamX .HasStreaming}} + streamer: c.(client.StreamX), + {{- end}} } } {{range .AllMethods}} -{{- if or .ClientStreaming .ServerStreaming}} {{- /* streaming logic */}} +{{- if and .StreamX .Streaming.IsStreaming}} +{{- $streamingUnary := (eq .Streaming.Mode "unary")}} +{{- $clientSide := (eq .Streaming.Mode "client")}} +{{- $serverSide := (eq .Streaming.Mode "server")}} +{{- $bidiSide := (eq .Streaming.Mode "bidirectional")}} +{{- $mode := ""}} + {{- if $bidiSide -}} {{- $mode = "kitex.StreamingBidirectional" }} + {{- else if $serverSide -}} {{- $mode = "kitex.StreamingServer" }} + {{- else if $clientSide -}} {{- $mode = "kitex.StreamingClient" }} + {{- else if $streamingUnary -}} {{- $mode = "kitex.StreamingUnary" }} + {{- end}} +{{- $arg := index .Args 0}} +func (p *kClient) {{.Name}}{{- if $streamingUnary}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) ({{.Resp.Type}}, error) { + res := new({{NotPtr .Resp.Type}}) + _, _, err := streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + ctx, p.streamer, {{$mode}}, "{{.RawName}}", req, res, callOptions...) + if err != nil { + return nil, err + } + return res, nil +} +{{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + context.Context, streamx.ClientStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], error, +) { + return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + ctx, p.streamer, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) +} +{{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) ( + context.Context, streamx.ServerStreamingClient[{{NotPtr .Resp.Type}}], error, +) { + return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + ctx, p.streamer, {{$mode}}, "{{.RawName}}", req, nil, callOptions...) +} +{{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( + context.Context, streamx.BidiStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], error, +) { + return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( + ctx, p.streamer, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) +} +{{- end}}{{/* if $streamingUnary end */}} +{{- else}}{{- /* old streaming interface */}} +{{- if or .ClientStreaming .ServerStreaming}} func (p *kClient) {{.Name}}(ctx context.Context{{if not .ClientStreaming}}{{range .Args}}, {{LowerFirst .Name}} {{.Type}}{{end}}{{end}}) ({{.ServiceName}}_{{.RawName}}Client, error) { streamClient, ok := p.c.(client.Streaming) if !ok { @@ -525,6 +621,7 @@ func (p *kClient) {{.Name}}(ctx context.Context {{range .Args}}, {{.RawName}} {{ {{end -}} } {{- end}} +{{- end}}{{/* if and .StreamX .Streaming.IsStreaming end */}} {{end}} {{- if .FrugalPretouch}} diff --git a/tool/internal_pkg/tpl/streamx/client.go b/tool/internal_pkg/tpl/streamx/client.go deleted file mode 100644 index 2a5e88ca2b..0000000000 --- a/tool/internal_pkg/tpl/streamx/client.go +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package streamx - -var ClientTpl = `// Code generated by Kitex {{.Version}}. DO NOT EDIT. - -package {{ToLower .ServiceName}} - -import ( - {{- range $path, $aliases := .Imports}} - {{- if not $aliases}} - "{{$path}}" - {{- else}} - {{- range $alias, $is := $aliases}} - {{$alias}} "{{$path}}" - {{- end}} - {{- end}} - {{- end}} -) -{{- $protocol := .Protocol | getStreamxRef}} - -type Client interface { -{{- range .AllMethods}} -{{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} -{{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} -{{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} -{{- $bidiSide := and .ClientStreaming .ServerStreaming}} -{{- $arg := index .Args 0}} - {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (r {{.Resp.Type}}, err error) - {{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.ClientStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) - {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (stream streamx.ServerStreamingClient[{{NotPtr .Resp.Type}}], err error) - {{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.BidiStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) - {{- end}} -{{- end}} -} - -func NewClient(destService string, opts ...streamxclient.Option) (Client, error) { - var options []streamxclient.Option - options = append(options, streamxclient.WithDestService(destService)) - options = append(options, opts...) - cp, err := {{$protocol}}.NewClientProvider(svcInfo) - if err != nil { - return nil, err - } - options = append(options, streamxclient.WithProvider(cp)) - cli, err := streamxclient.NewClient(svcInfo, options...) - if err != nil { - return nil, err - } - kc := &kClient{streamer: cli, caller: cli.(client.Client)} - return kc, nil -} - -var _ Client = (*kClient)(nil) - -type kClient struct { - caller client.Client - streamer streamxclient.Client -} - -{{- range .AllMethods}} -{{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} -{{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} -{{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} -{{- $bidiSide := and .ClientStreaming .ServerStreaming}} -{{- $mode := ""}} - {{- if $bidiSide -}} {{- $mode = "serviceinfo.StreamingBidirectional" }} - {{- else if $serverSide -}} {{- $mode = "serviceinfo.StreamingServer" }} - {{- else if $clientSide -}} {{- $mode = "serviceinfo.StreamingClient" }} - {{- else if $unary -}} {{- $mode = "serviceinfo.StreamingUnary" }} - {{- end}} -{{- $arg := index .Args 0}} -func (c *kClient) {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) ({{.Resp.Type}}, error) { - res := new({{NotPtr .Resp.Type}}) - _, err := streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( - ctx, c.streamer, {{$mode}}, "{{.RawName}}", req, res, callOptions...) - if err != nil { - return nil, err - } - return res, nil -{{- else if $clientSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.ClientStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) { - return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( - ctx, c.streamer, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) -{{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) (stream streamx.ServerStreamingClient[{{NotPtr .Resp.Type}}], err error) { - return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( - ctx, c.streamer, {{$mode}}, "{{.RawName}}", req, nil, callOptions...) -{{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) (stream streamx.BidiStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], err error) { - return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( - ctx, c.streamer, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) -{{- end}} -} -{{- end}} -` diff --git a/tool/internal_pkg/tpl/streamx/handler.method.go b/tool/internal_pkg/tpl/streamx/handler.method.go deleted file mode 100644 index acbbdd7571..0000000000 --- a/tool/internal_pkg/tpl/streamx/handler.method.go +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package streamx - -var HandlerMethodsTpl = `{{define "HandlerMethod"}} -{{- $protocol := .Protocol | getStreamxRef}} -{{- range .AllMethods}} -{{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} -{{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} -{{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} -{{- $bidiSide := and .ClientStreaming .ServerStreaming}} -{{- $arg := index .Args 0}} -func (s *{{.ServiceName}}Impl) {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}) (resp {{.Resp.Type}}, err error) { - {{- else if $clientSide}}(ctx context.Context, stream streamx.ClientStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) (resp {{.Resp.Type}}, err error) { - {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{NotPtr .Resp.Type}}]) (err error) { - {{- else if $bidiSide}}(ctx context.Context, stream streamx.BidiStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) (err error) { - {{- end}} - // TODO: Your code here... - return -} -{{- end}} -{{end}}{{/* define "HandlerMethod" */}} -` diff --git a/tool/internal_pkg/tpl/streamx/server.go b/tool/internal_pkg/tpl/streamx/server.go deleted file mode 100644 index 08eaa64483..0000000000 --- a/tool/internal_pkg/tpl/streamx/server.go +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package streamx - -var ServerTpl = `// Code generated by Kitex {{.Version}}. DO NOT EDIT. -package {{ToLower .ServiceName}} - -import ( - {{- range $path, $aliases := .Imports}} - {{- if not $aliases}} - "{{$path}}" - {{- else}} - {{- range $alias, $is := $aliases}} - {{$alias}} "{{$path}}" - {{- end}} - {{- end}} - {{- end}} -) -{{- $protocol := .Protocol | getStreamxRef}} - -type Server interface { -{{- range .AllMethods}} -{{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} -{{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} -{{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} -{{- $bidiSide := and .ClientStreaming .ServerStreaming}} -{{- $arg := index .Args 0}} - {{.Name}}{{- if $unary}}(ctx context.Context, req {{$arg.Type}}) ({{.Resp.Type}}, error) - {{- else if $clientSide}}(ctx context.Context, stream streamx.ClientStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) ({{.Resp.Type}}, error) - {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, stream streamx.ServerStreamingServer[{{NotPtr .Resp.Type}}]) error - {{- else if $bidiSide}}(ctx context.Context, stream streamx.BidiStreamingServer[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]) error - {{- end}} -{{- end}} -} - -func RegisterService(svr server.Server, handler Server, opts ...server.RegisterOption) error { - sp, err := {{$protocol}}.NewServerProvider(svcInfo) - if err != nil { - return err - } - nopts := []server.RegisterOption{ - streamxserver.WithProvider(sp), - } - nopts = append(nopts, opts...) - return svr.RegisterService(svcInfo, handler, nopts...) -} -` diff --git a/tool/internal_pkg/tpl/streamx/service.go b/tool/internal_pkg/tpl/streamx/service.go deleted file mode 100644 index d9bad47462..0000000000 --- a/tool/internal_pkg/tpl/streamx/service.go +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package streamx - -var ServiceTpl = `// Code generated by Kitex {{.Version}}. DO NOT EDIT. -package {{ToLower .ServiceName}} - -import ( - {{- range $path, $aliases := .Imports}} - {{- if not $aliases}} - "{{$path}}" - {{- else}} - {{- range $alias, $is := $aliases}} - {{$alias}} "{{$path}}" - {{- end}} - {{- end}} - {{- end}} -) -{{- $protocol := .Protocol | getStreamxRef}} - -var svcInfo = &serviceinfo.ServiceInfo{ - ServiceName: "{{.RawServiceName}}", - Methods: map[string]serviceinfo.MethodInfo{ - {{- range .AllMethods}} - {{- $unary := and (not .ServerStreaming) (not .ClientStreaming)}} - {{- $clientSide := and .ClientStreaming (not .ServerStreaming)}} - {{- $serverSide := and (not .ClientStreaming) .ServerStreaming}} - {{- $bidiSide := and .ClientStreaming .ServerStreaming}} - {{- $arg := index .Args 0}} - {{- $mode := ""}} - {{- if $bidiSide -}} {{- $mode = "serviceinfo.StreamingBidirectional" }} - {{- else if $serverSide -}} {{- $mode = "serviceinfo.StreamingServer" }} - {{- else if $clientSide -}} {{- $mode = "serviceinfo.StreamingClient" }} - {{- else if $unary -}} {{- $mode = "serviceinfo.StreamingUnary" }} - {{- end}} - "{{.RawName}}": serviceinfo.NewMethodInfo( - func(ctx context.Context, handler, reqArgs, resArgs interface{}) error { - return streamxserver.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( - ctx, {{$mode}}, handler.(streamx.StreamHandler), reqArgs.(streamx.StreamReqArgs), resArgs.(streamx.StreamResArgs)) - }, - nil, - nil, - false, - serviceinfo.WithStreamingMode({{$mode}}), - ), - {{- end}} - }, - Extra: map[string]interface{}{ - "streaming": true, - "streamx": true, - }, -} - -` From 10e5bdfb208e5ac57b0d9d540c5d1bbf758e5205 Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Mon, 18 Nov 2024 17:06:10 +0800 Subject: [PATCH 28/34] feat: add ttheader streaming protocol --- transport/keys.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/transport/keys.go b/transport/keys.go index 93032fae9c..43b39e4c84 100644 --- a/transport/keys.go +++ b/transport/keys.go @@ -29,9 +29,10 @@ const ( HTTP GRPC HESSIAN2 - JSONRPC + STREAMING - TTHeaderFramed = TTHeader | Framed + TTHeaderFramed = TTHeader | Framed + TTHeaderStreaming = TTHeader | STREAMING ) // Unknown indicates the protocol is unknown. @@ -54,6 +55,8 @@ func (tp Protocol) String() string { return "GRPC" case HESSIAN2: return "Hessian2" + case TTHeaderStreaming: + return "TTHeaderStreaming" } return Unknown } From e9a975046f07c7e114d449a00d86515f7e9f2ac5 Mon Sep 17 00:00:00 2001 From: Joway Date: Tue, 19 Nov 2024 14:04:22 +0800 Subject: [PATCH 29/34] fix: traceCtx should assgin to ctx (#1618) --- go.mod | 2 +- go.sum | 8 +- internal/server/remote_option.go | 3 + pkg/remote/trans/streamx/server_handler.go | 4 +- pkg/streamx/provider/ttstream/mock_test.go | 4 +- pkg/streamx/stream_args.go | 7 +- pkg/streamx/stream_middleware_internal.go | 12 +- pkg/streamx/streamx_gen_codec_test.go | 223 +++++++-------------- pkg/streamx/streamx_user_test.go | 61 +++--- server/server.go | 20 -- server/streamxserver/option.go | 10 + 11 files changed, 130 insertions(+), 224 deletions(-) diff --git a/go.mod b/go.mod index 030a562a8e..2a797a2e9f 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/cloudwego/dynamicgo v0.4.6-0.20241115162834-0e99bc39b128 github.com/cloudwego/fastpb v0.0.5 github.com/cloudwego/frugal v0.2.0 - github.com/cloudwego/gopkg v0.1.2-0.20240919030844-cb7123236682 + github.com/cloudwego/gopkg v0.1.3-0.20241118053554-db5d7d475e7e github.com/cloudwego/localsession v0.1.1 github.com/cloudwego/netpoll v0.6.5-0.20240911104114-8a1f5597a920 github.com/cloudwego/runtimex v0.1.0 diff --git a/go.sum b/go.sum index 81e9cf5cb6..ff4ca7421f 100644 --- a/go.sum +++ b/go.sum @@ -22,16 +22,12 @@ github.com/cloudwego/fastpb v0.0.5 h1:vYnBPsfbAtU5TVz5+f9UTlmSCixG9F9vRwaqE0mZPZ github.com/cloudwego/fastpb v0.0.5/go.mod h1:Bho7aAKBUtT9RPD2cNVkTdx4yQumfSv3If7wYnm1izk= github.com/cloudwego/frugal v0.2.0 h1:0ETSzQYoYqVvdl7EKjqJ9aJnDoG6TzvNKV3PMQiQTS8= github.com/cloudwego/frugal v0.2.0/go.mod h1:cpnV6kdRMjN3ylxRo63RNbZ9rBK6oxs70Zk6QZ4Enj4= -github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4 h1:SHw9GUBBcAnLWeK2MtPH7O6YQG9Q2ZZ8koD/4alpLvE= -github.com/cloudwego/gopkg v0.1.2-0.20240910075652-f542979ecca4/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= -github.com/cloudwego/gopkg v0.1.2-0.20240919030844-cb7123236682 h1:hj/AhlEngERp5Tjt864veEvyK6RglXKcXpxkIOSRfug= -github.com/cloudwego/gopkg v0.1.2-0.20240919030844-cb7123236682/go.mod h1:WoNTdXDPdvL97cBmRUWXVGkh2l2UFmpd9BUvbW2r0Aw= +github.com/cloudwego/gopkg v0.1.3-0.20241118053554-db5d7d475e7e h1:fRZIRv5bgpF/9TBlQMYwryV6d2BXDLw2MEghdzFecXY= +github.com/cloudwego/gopkg v0.1.3-0.20241118053554-db5d7d475e7e/go.mod h1:FQuXsRWRsSqJLsMVd5SYzp8/Z1y5gXKnVvRrWUOsCMI= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/localsession v0.1.1 h1:tbK7laDVrYfFDXoBXo4uCGMAxU4qmz2dDm8d4BGBnDo= github.com/cloudwego/localsession v0.1.1/go.mod h1:kiJxmvAcy4PLgKtEnPS5AXed3xCiXcs7Z+KBHP72Wv8= -github.com/cloudwego/netpoll v0.6.4 h1:z/dA4sOTUQof6zZIO4QNnLBXsDFFFEos9OOGloR6kno= -github.com/cloudwego/netpoll v0.6.4/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= github.com/cloudwego/netpoll v0.6.5-0.20240911104114-8a1f5597a920 h1:WT7vsDDb+ammyB7XLmNSS4vKGpPvM2JDl6h34Jj7mY4= github.com/cloudwego/netpoll v0.6.5-0.20240911104114-8a1f5597a920/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= github.com/cloudwego/runtimex v0.1.0 h1:HG+WxWoj5/CDChDZ7D99ROwvSMkuNXAqt6hnhTTZDiI= diff --git a/internal/server/remote_option.go b/internal/server/remote_option.go index 1fc22d71de..a814df152c 100644 --- a/internal/server/remote_option.go +++ b/internal/server/remote_option.go @@ -23,13 +23,16 @@ package server import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/trans/detection" "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" ) func newServerRemoteOption() *remote.ServerOption { return &remote.ServerOption{ TransServerFactory: netpoll.NewTransServerFactory(), + SvrHandlerFactory: detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()), Codec: codec.NewDefaultCodec(), Address: defaultAddress, ExitWaitTime: defaultExitWaitTime, diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index 92fe2193e1..be2a57fa22 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -120,11 +120,11 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) } wg.Add(1) // stream level goroutine - streamWorkerPool.GoCtx(ctx, func() { + streamWorkerPool.GoCtx(nctx, func() { defer wg.Done() err := t.OnStream(nctx, conn, ss) if err != nil && !errors.Is(err, io.EOF) { - klog.CtxErrorf(ctx, "KITEX: stream ReadStream failed: err=%v", err) + klog.CtxErrorf(nctx, "KITEX: stream ReadStream failed: err=%v", err) } }) } diff --git a/pkg/streamx/provider/ttstream/mock_test.go b/pkg/streamx/provider/ttstream/mock_test.go index 27701cf4b2..5d620f2ff8 100644 --- a/pkg/streamx/provider/ttstream/mock_test.go +++ b/pkg/streamx/provider/ttstream/mock_test.go @@ -20,8 +20,8 @@ import ( "fmt" "github.com/cloudwego/frugal" + "github.com/cloudwego/gopkg/protocol/thrift" - "github.com/cloudwego/kitex/pkg/protocol/bthrift" kutils "github.com/cloudwego/kitex/pkg/utils" ) @@ -34,7 +34,7 @@ func (p *testRequest) FastRead(buf []byte) (int, error) { return frugal.DecodeObject(buf, p) } -func (p *testRequest) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *testRequest) FastWriteNocopy(buf []byte, binaryWriter thrift.NocopyWriter) int { n, _ := frugal.EncodeObject(buf, binaryWriter, p) return n } diff --git a/pkg/streamx/stream_args.go b/pkg/streamx/stream_args.go index 9592edfd2c..f82c7b08c9 100644 --- a/pkg/streamx/stream_args.go +++ b/pkg/streamx/stream_args.go @@ -21,15 +21,14 @@ import ( "errors" ) -type StreamCtxKey struct{} +type streamCtxKey struct{} func WithStreamArgsContext(ctx context.Context, args StreamArgs) context.Context { - ctx = context.WithValue(ctx, StreamCtxKey{}, args) - return ctx + return context.WithValue(ctx, streamCtxKey{}, args) } func GetStreamArgsFromContext(ctx context.Context) (args StreamArgs) { - val := ctx.Value(StreamCtxKey{}) + val := ctx.Value(streamCtxKey{}) if val == nil { return nil } diff --git a/pkg/streamx/stream_middleware_internal.go b/pkg/streamx/stream_middleware_internal.go index a6d5d2daf1..96c6444aae 100644 --- a/pkg/streamx/stream_middleware_internal.go +++ b/pkg/streamx/stream_middleware_internal.go @@ -28,10 +28,10 @@ func NewStreamRecvStatMiddleware(traceCtx context.Context, ehandler EventHandler err = next(ctx, stream, res) // if traceCtx is nil, using the current ctx // otherwise, we should use the original trace ctx instead - if traceCtx == nil { - traceCtx = ctx + if traceCtx != nil { + ctx = traceCtx } - ehandler(traceCtx, stats.StreamRecv, err) + ehandler(ctx, stats.StreamRecv, err) return err } } @@ -43,10 +43,10 @@ func NewStreamSendStatMiddleware(traceCtx context.Context, ehandler EventHandler err = next(ctx, stream, res) // if traceCtx is nil, using the current ctx // otherwise, we should use the original trace ctx instead - if traceCtx == nil { - traceCtx = ctx + if traceCtx != nil { + ctx = traceCtx } - ehandler(traceCtx, stats.StreamSend, err) + ehandler(ctx, stats.StreamSend, err) return err } } diff --git a/pkg/streamx/streamx_gen_codec_test.go b/pkg/streamx/streamx_gen_codec_test.go index 67793454aa..7584969c51 100644 --- a/pkg/streamx/streamx_gen_codec_test.go +++ b/pkg/streamx/streamx_gen_codec_test.go @@ -22,10 +22,7 @@ import ( "reflect" "strings" - "github.com/apache/thrift/lib/go/thrift" - - "github.com/cloudwego/kitex/pkg/protocol/bthrift" - kutils "github.com/cloudwego/kitex/pkg/utils" + "github.com/cloudwego/gopkg/protocol/thrift" ) // unused protection @@ -34,8 +31,6 @@ var ( _ = (*bytes.Buffer)(nil) _ = (*strings.Builder)(nil) _ = reflect.Type(nil) - _ = thrift.TProtocol(nil) - _ = bthrift.BinaryWriter(nil) ) var fieldIDToName_Request = map[int16]string{ @@ -84,14 +79,10 @@ func (p *Request) FastRead(buf []byte) (int, error) { var l int var fieldTypeId thrift.TType var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - + var issetType bool = false + var issetMessage bool = false for { - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) offset += l if err != nil { goto ReadFieldBeginError @@ -107,8 +98,9 @@ func (p *Request) FastRead(buf []byte) (int, error) { if err != nil { goto ReadFieldError } + issetType = true } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError @@ -121,59 +113,52 @@ func (p *Request) FastRead(buf []byte) (int, error) { if err != nil { goto ReadFieldError } + issetMessage = true } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError + + if !issetType { + fieldId = 1 + goto RequiredFieldNotSetError } + if !issetMessage { + fieldId = 2 + goto RequiredFieldNotSetError + } return offset, nil -ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Request[fieldId]), err) SkipFieldError: return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +RequiredFieldNotSetError: + return offset, thrift.NewProtocolException(thrift.INVALID_DATA, fmt.Sprintf("required field %s is not set", fieldIDToName_Request[fieldId])) } func (p *Request) FastReadField1(buf []byte) (int, error) { offset := 0 var _field int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l - _field = v - } p.Type = _field return offset, nil @@ -183,108 +168,78 @@ func (p *Request) FastReadField2(buf []byte) (int, error) { offset := 0 var _field string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - _field = v - } p.Message = _field return offset, nil } -// for compatibility func (p *Request) FastWrite(buf []byte) int { - return 0 + return p.FastWriteNocopy(buf, nil) } -func (p *Request) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *Request) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "Request") if p != nil { - offset += p.fastWriteField1(buf[offset:], binaryWriter) - offset += p.fastWriteField2(buf[offset:], binaryWriter) + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset } func (p *Request) BLength() int { l := 0 - l += bthrift.Binary.StructBeginLength("Request") if p != nil { l += p.field1Length() l += p.field2Length() } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() + l += thrift.Binary.FieldStopLength() return l } -func (p *Request) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *Request) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Type", thrift.I32, 1) - offset += bthrift.Binary.WriteI32(buf[offset:], p.Type) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 1) + offset += thrift.Binary.WriteI32(buf[offset:], p.Type) return offset } -func (p *Request) fastWriteField2(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *Request) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Message", thrift.STRING, 2) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Message) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, p.Message) return offset } func (p *Request) field1Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("Type", thrift.I32, 1) - l += bthrift.Binary.I32Length(p.Type) - l += bthrift.Binary.FieldEndLength() + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() return l } func (p *Request) field2Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("Message", thrift.STRING, 2) - l += bthrift.Binary.StringLengthNocopy(p.Message) - l += bthrift.Binary.FieldEndLength() + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(p.Message) return l } -func (p *Request) DeepCopy(s interface{}) error { - src, ok := s.(*Request) - if !ok { - return fmt.Errorf("%T's type not matched %T", s, p) - } - - p.Type = src.Type - - if src.Message != "" { - p.Message = kutils.StringDeepCopy(src.Message) - } - - return nil -} - func (p *Response) FastRead(buf []byte) (int, error) { var err error var offset int var l int var fieldTypeId thrift.TType var fieldId int16 - _, l, err = bthrift.Binary.ReadStructBegin(buf) - offset += l - if err != nil { - goto ReadStructBeginError - } - + var issetType bool = false + var issetMessage bool = false for { - _, fieldTypeId, fieldId, l, err = bthrift.Binary.ReadFieldBegin(buf[offset:]) + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) offset += l if err != nil { goto ReadFieldBeginError @@ -300,8 +255,9 @@ func (p *Response) FastRead(buf []byte) (int, error) { if err != nil { goto ReadFieldError } + issetType = true } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError @@ -314,59 +270,52 @@ func (p *Response) FastRead(buf []byte) (int, error) { if err != nil { goto ReadFieldError } + issetMessage = true } else { - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } default: - l, err = bthrift.Binary.Skip(buf[offset:], fieldTypeId) + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l if err != nil { goto SkipFieldError } } - - l, err = bthrift.Binary.ReadFieldEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadFieldEndError - } } - l, err = bthrift.Binary.ReadStructEnd(buf[offset:]) - offset += l - if err != nil { - goto ReadStructEndError + + if !issetType { + fieldId = 1 + goto RequiredFieldNotSetError } + if !issetMessage { + fieldId = 2 + goto RequiredFieldNotSetError + } return offset, nil -ReadStructBeginError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Response[fieldId]), err) SkipFieldError: return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) -ReadFieldEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) -ReadStructEndError: - return offset, thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +RequiredFieldNotSetError: + return offset, thrift.NewProtocolException(thrift.INVALID_DATA, fmt.Sprintf("required field %s is not set", fieldIDToName_Response[fieldId])) } func (p *Response) FastReadField1(buf []byte) (int, error) { offset := 0 var _field int32 - if v, l, err := bthrift.Binary.ReadI32(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadI32(buf[offset:]); err != nil { return offset, err } else { offset += l - _field = v - } p.Type = _field return offset, nil @@ -376,90 +325,64 @@ func (p *Response) FastReadField2(buf []byte) (int, error) { offset := 0 var _field string - if v, l, err := bthrift.Binary.ReadString(buf[offset:]); err != nil { + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { return offset, err } else { offset += l - _field = v - } p.Message = _field return offset, nil } -// for compatibility func (p *Response) FastWrite(buf []byte) int { - return 0 + return p.FastWriteNocopy(buf, nil) } -func (p *Response) FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *Response) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteStructBegin(buf[offset:], "Response") if p != nil { - offset += p.fastWriteField1(buf[offset:], binaryWriter) - offset += p.fastWriteField2(buf[offset:], binaryWriter) + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) } - offset += bthrift.Binary.WriteFieldStop(buf[offset:]) - offset += bthrift.Binary.WriteStructEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset } func (p *Response) BLength() int { l := 0 - l += bthrift.Binary.StructBeginLength("Response") if p != nil { l += p.field1Length() l += p.field2Length() } - l += bthrift.Binary.FieldStopLength() - l += bthrift.Binary.StructEndLength() + l += thrift.Binary.FieldStopLength() return l } -func (p *Response) fastWriteField1(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *Response) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Type", thrift.I32, 1) - offset += bthrift.Binary.WriteI32(buf[offset:], p.Type) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.I32, 1) + offset += thrift.Binary.WriteI32(buf[offset:], p.Type) return offset } -func (p *Response) fastWriteField2(buf []byte, binaryWriter bthrift.BinaryWriter) int { +func (p *Response) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { offset := 0 - offset += bthrift.Binary.WriteFieldBegin(buf[offset:], "Message", thrift.STRING, 2) - offset += bthrift.Binary.WriteStringNocopy(buf[offset:], binaryWriter, p.Message) - offset += bthrift.Binary.WriteFieldEnd(buf[offset:]) + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, p.Message) return offset } func (p *Response) field1Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("Type", thrift.I32, 1) - l += bthrift.Binary.I32Length(p.Type) - l += bthrift.Binary.FieldEndLength() + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.I32Length() return l } func (p *Response) field2Length() int { l := 0 - l += bthrift.Binary.FieldBeginLength("Message", thrift.STRING, 2) - l += bthrift.Binary.StringLengthNocopy(p.Message) - l += bthrift.Binary.FieldEndLength() + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(p.Message) return l } - -func (p *Response) DeepCopy(s interface{}) error { - src, ok := s.(*Response) - if !ok { - return fmt.Errorf("%T's type not matched %T", s, p) - } - - p.Type = src.Type - - if src.Message != "" { - p.Message = kutils.StringDeepCopy(src.Message) - } - - return nil -} diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index 6b223f5856..a410f794c8 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -141,7 +141,6 @@ func TestStreamingBasic(t *testing.T) { return err } }), - streamxserver.WithProvider(tc.ServerProvider), streamxserver.WithStreamRecvMiddleware(func(next streamx.StreamRecvEndpoint) streamx.StreamRecvEndpoint { return func(ctx context.Context, stream streamx.Stream, res any) (err error) { @@ -338,14 +337,14 @@ func TestStreamingBasic(t *testing.T) { } waitServerStreamDone(concurrency) wg.Wait() - test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientMiddlewareCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverStreamCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientStreamCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientRecvCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientSendCount), int32(concurrency)) + untilEqual(t, &serverMiddlewareCount, int32(concurrency), time.Second) + untilEqual(t, &clientMiddlewareCount, int32(concurrency), time.Second) + untilEqual(t, &serverStreamCount, int32(concurrency), time.Second) + untilEqual(t, &clientStreamCount, int32(concurrency), time.Second) + untilEqual(t, &serverRecvCount, int32(concurrency), time.Second) + untilEqual(t, &clientRecvCount, int32(concurrency), time.Second) + untilEqual(t, &serverSendCount, int32(concurrency), time.Second) + untilEqual(t, &clientSendCount, int32(concurrency), time.Second) resetServerCount() resetClientCount() @@ -372,15 +371,14 @@ func TestStreamingBasic(t *testing.T) { } waitServerStreamDone(concurrency) wg.Wait() + untilEqual(t, &serverMiddlewareCount, int32(concurrency), time.Second) + untilEqual(t, &clientMiddlewareCount, int32(concurrency), time.Second) untilEqual(t, &serverStreamCount, int32(concurrency), time.Second) - test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientMiddlewareCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverStreamCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientStreamCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(round*concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientRecvCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientSendCount), int32(round*concurrency)) + untilEqual(t, &clientStreamCount, int32(concurrency), time.Second) + untilEqual(t, &serverRecvCount, int32(round*concurrency), time.Second) + untilEqual(t, &clientRecvCount, int32(concurrency), time.Second) + untilEqual(t, &serverSendCount, int32(concurrency), time.Second) + untilEqual(t, &clientSendCount, int32(round*concurrency), time.Second) resetServerCount() resetClientCount() @@ -409,15 +407,14 @@ func TestStreamingBasic(t *testing.T) { } waitServerStreamDone(concurrency) wg.Wait() + untilEqual(t, &serverMiddlewareCount, int32(concurrency), time.Second) + untilEqual(t, &clientMiddlewareCount, int32(concurrency), time.Second) untilEqual(t, &serverStreamCount, int32(concurrency), time.Second) - test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientMiddlewareCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverStreamCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientStreamCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientRecvCount), int32(round*concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(round*concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientSendCount), int32(concurrency)) + untilEqual(t, &clientStreamCount, int32(concurrency), time.Second) + untilEqual(t, &serverRecvCount, int32(concurrency), time.Second) + untilEqual(t, &clientRecvCount, int32(round*concurrency), time.Second) + untilEqual(t, &serverSendCount, int32(round*concurrency), time.Second) + untilEqual(t, &clientSendCount, int32(concurrency), time.Second) resetServerCount() resetClientCount() @@ -460,16 +457,14 @@ func TestStreamingBasic(t *testing.T) { } waitServerStreamDone(concurrency) wg.Wait() + untilEqual(t, &serverMiddlewareCount, int32(concurrency), time.Second) + untilEqual(t, &clientMiddlewareCount, int32(concurrency), time.Second) untilEqual(t, &serverStreamCount, int32(concurrency), time.Second) untilEqual(t, &clientStreamCount, int32(concurrency), time.Second) - test.DeepEqual(t, atomic.LoadInt32(&serverMiddlewareCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientMiddlewareCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverStreamCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientStreamCount), int32(concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverRecvCount), int32(round*concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientRecvCount), int32(round*concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&serverSendCount), int32(round*concurrency)) - test.DeepEqual(t, atomic.LoadInt32(&clientSendCount), int32(round*concurrency)) + untilEqual(t, &serverRecvCount, int32(round*concurrency), time.Second) + untilEqual(t, &clientRecvCount, int32(round*concurrency), time.Second) + untilEqual(t, &serverSendCount, int32(round*concurrency), time.Second) + untilEqual(t, &clientSendCount, int32(round*concurrency), time.Second) resetServerCount() resetClientCount() diff --git a/server/server.go b/server/server.go index cb68710aa6..40e335b8da 100644 --- a/server/server.go +++ b/server/server.go @@ -29,10 +29,6 @@ import ( "github.com/cloudwego/localsession/backup" - "github.com/cloudwego/kitex/pkg/remote/trans/detection" - "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" - "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" - internal_server "github.com/cloudwego/kitex/internal/server" "github.com/cloudwego/kitex/pkg/acl" "github.com/cloudwego/kitex/pkg/diagnosis" @@ -46,7 +42,6 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/bound" "github.com/cloudwego/kitex/pkg/remote/remotesvr" - streamxstrans "github.com/cloudwego/kitex/pkg/remote/trans/streamx" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" @@ -546,21 +541,6 @@ func doAddBoundHandler(h remote.BoundHandler, opt *remote.ServerOption) { func (s *server) newSvrTransHandler() (handler remote.ServerTransHandler, err error) { transHdlrFactory := s.opt.RemoteOpt.SvrHandlerFactory - if transHdlrFactory == nil { - candidateFactories := make([]remote.ServerTransHandlerFactory, 0) - if s.opt.StreamX.Provider != nil { - candidateFactories = append(candidateFactories, - streamxstrans.NewSvrTransHandlerFactory(s.opt.StreamX.Provider), - ) - } - candidateFactories = append(candidateFactories, - nphttp2.NewSvrTransHandlerFactory(), - ) - transHdlrFactory = detection.NewSvrTransHandlerFactory( - netpoll.NewSvrTransHandlerFactory(), - candidateFactories..., - ) - } transHdlr, err := transHdlrFactory.NewTransHandler(s.opt.RemoteOpt) if err != nil { return nil, err diff --git a/server/streamxserver/option.go b/server/streamxserver/option.go index 2e765e1dd6..b097c91a59 100644 --- a/server/streamxserver/option.go +++ b/server/streamxserver/option.go @@ -18,6 +18,10 @@ package streamxserver import ( internal_server "github.com/cloudwego/kitex/internal/server" + "github.com/cloudwego/kitex/pkg/remote/trans/detection" + "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" + streamxstrans "github.com/cloudwego/kitex/pkg/remote/trans/streamx" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/utils" "github.com/cloudwego/kitex/server" @@ -25,7 +29,13 @@ import ( func WithProvider(provider streamx.ServerProvider) server.Option { return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { + // streamx provider server trans handler can use with other protocol o.StreamX.Provider = provider + o.RemoteOpt.SvrHandlerFactory = detection.NewSvrTransHandlerFactory( + netpoll.NewSvrTransHandlerFactory(), + streamxstrans.NewSvrTransHandlerFactory(provider), + nphttp2.NewSvrTransHandlerFactory(), + ) }} } From dc99b5a23f839a33f41c1d9f3a104c6b014a2493 Mon Sep 17 00:00:00 2001 From: Joway Date: Wed, 20 Nov 2024 15:26:29 +0800 Subject: [PATCH 30/34] fix: use old client api (#1620) --- tool/internal_pkg/tpl/service.go | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tool/internal_pkg/tpl/service.go b/tool/internal_pkg/tpl/service.go index 745eb802af..6a5c74cc26 100644 --- a/tool/internal_pkg/tpl/service.go +++ b/tool/internal_pkg/tpl/service.go @@ -498,17 +498,11 @@ func (p *{{.ResStructName}}) GetResult() interface{} { type kClient struct { c client.Client - {{- if and .StreamX .HasStreaming}} - streamer client.StreamX - {{- end}} } func newServiceClient(c client.Client) *kClient { return &kClient{ c: c, - {{- if and .StreamX .HasStreaming}} - streamer: c.(client.StreamX), - {{- end}} } } @@ -529,7 +523,7 @@ func newServiceClient(c client.Client) *kClient { func (p *kClient) {{.Name}}{{- if $streamingUnary}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) ({{.Resp.Type}}, error) { res := new({{NotPtr .Resp.Type}}) _, _, err := streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( - ctx, p.streamer, {{$mode}}, "{{.RawName}}", req, res, callOptions...) + ctx, p.c, {{$mode}}, "{{.RawName}}", req, res, callOptions...) if err != nil { return nil, err } @@ -539,19 +533,19 @@ func (p *kClient) {{.Name}}{{- if $streamingUnary}}(ctx context.Context, req {{$ context.Context, streamx.ClientStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], error, ) { return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( - ctx, p.streamer, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) + ctx, p.c, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) } {{- else if $serverSide}}(ctx context.Context, req {{$arg.Type}}, callOptions ...streamxcallopt.CallOption) ( context.Context, streamx.ServerStreamingClient[{{NotPtr .Resp.Type}}], error, ) { return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( - ctx, p.streamer, {{$mode}}, "{{.RawName}}", req, nil, callOptions...) + ctx, p.c, {{$mode}}, "{{.RawName}}", req, nil, callOptions...) } {{- else if $bidiSide}}(ctx context.Context, callOptions ...streamxcallopt.CallOption) ( context.Context, streamx.BidiStreamingClient[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}], error, ) { return streamxclient.InvokeStream[{{NotPtr $arg.Type}}, {{NotPtr .Resp.Type}}]( - ctx, p.streamer, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) + ctx, p.c, {{$mode}}, "{{.RawName}}", nil, nil, callOptions...) } {{- end}}{{/* if $streamingUnary end */}} {{- else}}{{- /* old streaming interface */}} From b0f3a460c47dd1a479a113b6a30bdfbbea6546f3 Mon Sep 17 00:00:00 2001 From: Joway Date: Wed, 20 Nov 2024 17:56:27 +0800 Subject: [PATCH 31/34] fix: gen normal code in streamx mode (#1621) --- tool/internal_pkg/tpl/client.go | 16 ++++++---------- tool/internal_pkg/tpl/service.go | 22 +--------------------- 2 files changed, 7 insertions(+), 31 deletions(-) diff --git a/tool/internal_pkg/tpl/client.go b/tool/internal_pkg/tpl/client.go index 6044775ed9..b362b26daa 100644 --- a/tool/internal_pkg/tpl/client.go +++ b/tool/internal_pkg/tpl/client.go @@ -29,14 +29,14 @@ import ( {{- end}} {{- end}} {{- if .HasStreaming}} - {{- if not .StreamX}} + {{- if not .StreamX}} "github.com/cloudwego/kitex/client/streamclient" "github.com/cloudwego/kitex/client/callopt/streamcall" - {{- else}} - "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" - "github.com/cloudwego/kitex/pkg/streamx" - {{- end}}{{- /* if not .StreamX end */}} - {{- end}} + {{- else}} + "github.com/cloudwego/kitex/client/streamxclient/streamxcallopt" + "github.com/cloudwego/kitex/pkg/streamx" + {{- end}}{{- /* if not .StreamX end */}} + {{- end}} ) // Client is designed to provide IDL-compatible methods with call-option parameter for kitex framework. type Client interface { @@ -107,13 +107,9 @@ func NewClient(destService string, opts ...client.Option) (Client, error) { {{end}} options = append(options, opts...) - {{- if and .StreamX .HasStreaming}} - kc, err := client.NewClient(serviceInfo(), options...) - {{- else}} kc, err := client.NewClient( {{- if eq $.Codec "protobuf"}}serviceInfo(){{else}}serviceInfoForClient(){{end -}} , options...) - {{- end}}{{/* if .StreamX .HasStreaming end */}} if err != nil { return nil, err } diff --git a/tool/internal_pkg/tpl/service.go b/tool/internal_pkg/tpl/service.go index 6a5c74cc26..d574c5d8ea 100644 --- a/tool/internal_pkg/tpl/service.go +++ b/tool/internal_pkg/tpl/service.go @@ -65,36 +65,17 @@ var serviceMethods = map[string]kitex.MethodInfo{ {{- end}} } -{{- if and .StreamX .HasStreaming}} -var {{LowerFirst .ServiceName}}ServiceInfo = NewServiceInfo() -{{- else}} var ( {{LowerFirst .ServiceName}}ServiceInfo = NewServiceInfo() {{LowerFirst .ServiceName}}ServiceInfoForClient = NewServiceInfoForClient() {{LowerFirst .ServiceName}}ServiceInfoForStreamClient = NewServiceInfoForStreamClient() ) -{{- end}} {{- /* if and .StreamX .HasStreaming end */}} // for server func serviceInfo() *kitex.ServiceInfo { return {{LowerFirst .ServiceName}}ServiceInfo } -{{- if and .StreamX .HasStreaming}} -// NewServiceInfo creates a new ServiceInfo containing all methods -{{- /* It's for the Server (providing both streaming/non-streaming APIs), or for the grpc client */}} -func NewServiceInfo() *kitex.ServiceInfo { - return newServiceInfo() -} - -func newServiceInfo() *kitex.ServiceInfo { - return &kitex.ServiceInfo{ - ServiceName: "{{.RawServiceName}}", - PayloadCodec: kitex.Thrift, - Methods: serviceMethods, - } -} -{{- else}} {{- /* old streaming interface */}} // for stream client func serviceInfoForStreamClient() *kitex.ServiceInfo { return {{LowerFirst .ServiceName}}ServiceInfoForStreamClient @@ -108,7 +89,7 @@ func serviceInfoForClient() *kitex.ServiceInfo { // NewServiceInfo creates a new ServiceInfo containing all methods {{- /* It's for the Server (providing both streaming/non-streaming APIs), or for the grpc client */}} func NewServiceInfo() *kitex.ServiceInfo { - return newServiceInfo({{- if .HasStreaming}}true{{else}}false{{end}}, true, true) + return newServiceInfo({{- .HasStreaming }}, true, true) } // NewServiceInfo creates a new ServiceInfo containing non-streaming methods @@ -162,7 +143,6 @@ func newServiceInfo(hasStreaming bool, keepStreamingMethods bool, keepNonStreami } return svcInfo } -{{- end}}{{- /* if and .StreamX .HasStreaming end */}} {{range .AllMethods}} {{- $isStreaming := or .ClientStreaming .ServerStreaming}} From c3b67291779920b2386cfb5791c3a38f353fa496 Mon Sep 17 00:00:00 2001 From: Joway Date: Wed, 27 Nov 2024 16:53:31 +0800 Subject: [PATCH 32/34] refactor: rm sinfo from provider api (#1626) --- client/streamxclient/client_option.go | 4 ++ internal/server/remote_option.go | 7 +++- internal/streamx/stream.go | 5 --- pkg/remote/remotecli/stream.go | 4 +- pkg/remote/trans/streamx/server_handler.go | 16 +++++-- pkg/streamx/client_provider_internal.go | 40 ------------------ .../provider/ttstream/client_provier.go | 11 ++--- .../provider/ttstream/client_trans_pool.go | 4 +- .../ttstream/client_trans_pool_longconn.go | 5 +-- .../ttstream/client_trans_pool_muxconn.go | 9 ++-- .../ttstream/client_trans_pool_shortconn.go | 5 +-- .../provider/ttstream/server_provider.go | 9 ++-- pkg/streamx/provider/ttstream/stream.go | 8 +--- pkg/streamx/provider/ttstream/test_utils.go | 4 +- pkg/streamx/provider/ttstream/transport.go | 21 ++++------ .../provider/ttstream/transport_test.go | 12 +++--- pkg/streamx/server_provider_internal.go | 42 ------------------- pkg/streamx/streamx_user_test.go | 8 ++-- server/streamxserver/option.go | 5 ++- 19 files changed, 63 insertions(+), 156 deletions(-) delete mode 100644 pkg/streamx/client_provider_internal.go delete mode 100644 pkg/streamx/server_provider_internal.go diff --git a/client/streamxclient/client_option.go b/client/streamxclient/client_option.go index 9c55c3fdcb..4d8a79bae1 100644 --- a/client/streamxclient/client_option.go +++ b/client/streamxclient/client_option.go @@ -24,24 +24,28 @@ import ( "github.com/cloudwego/kitex/pkg/utils" ) +// Deprecated: Note that it maybe refactor in the next version func WithProvider(pvd streamx.ClientProvider) internal_client.Option { return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { o.RemoteOpt.Provider = pvd }} } +// WithStreamRecvTimeout add recv timeout for stream.Recv function func WithStreamRecvTimeout(timeout time.Duration) internal_client.Option { return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { o.StreamX.RecvTimeout = timeout }} } +// WithStreamRecvMiddleware add recv middleware func WithStreamRecvMiddleware(smw streamx.StreamRecvMiddleware) internal_client.Option { return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { o.StreamX.StreamRecvMWs = append(o.StreamX.StreamRecvMWs, smw) }} } +// WithStreamSendMiddleware add send middleware func WithStreamSendMiddleware(smw streamx.StreamSendMiddleware) internal_client.Option { return internal_client.Option{F: func(o *internal_client.Options, di *utils.Slice) { o.StreamX.StreamSendMWs = append(o.StreamX.StreamSendMWs, smw) diff --git a/internal/server/remote_option.go b/internal/server/remote_option.go index a814df152c..fabd520d1a 100644 --- a/internal/server/remote_option.go +++ b/internal/server/remote_option.go @@ -31,8 +31,11 @@ import ( func newServerRemoteOption() *remote.ServerOption { return &remote.ServerOption{ - TransServerFactory: netpoll.NewTransServerFactory(), - SvrHandlerFactory: detection.NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()), + TransServerFactory: netpoll.NewTransServerFactory(), + SvrHandlerFactory: detection.NewSvrTransHandlerFactory( + netpoll.NewSvrTransHandlerFactory(), + nphttp2.NewSvrTransHandlerFactory(), + ), Codec: codec.NewDefaultCodec(), Address: defaultAddress, ExitWaitTime: defaultExitWaitTime, diff --git a/internal/streamx/stream.go b/internal/streamx/stream.go index af94d9d424..681b974b52 100644 --- a/internal/streamx/stream.go +++ b/internal/streamx/stream.go @@ -16,12 +16,7 @@ package streamx -import ( - "github.com/cloudwego/kitex/pkg/serviceinfo" -) - type StreamInfo interface { Service() string Method() string - Mode() serviceinfo.StreamingMode } diff --git a/pkg/remote/remotecli/stream.go b/pkg/remote/remotecli/stream.go index 1ea1df172d..bf88e34e25 100644 --- a/pkg/remote/remotecli/stream.go +++ b/pkg/remote/remotecli/stream.go @@ -23,7 +23,6 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/pkg/streamx" ) // NewStream create a client side stream @@ -39,8 +38,7 @@ func NewStream(ctx context.Context, ri rpcinfo.RPCInfo, handler remote.ClientTra // streamx provider if opt.Provider != nil { // wrap internal client provider - clientProvider := streamx.NewClientProvider(opt.Provider) - cs, err := clientProvider.NewStream(ctx, ri) + cs, err := opt.Provider.NewStream(ctx, ri) if err != nil { return nil, nil, err } diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index be2a57fa22..ef35978d44 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -19,6 +19,7 @@ package streamx import ( "context" "errors" + "fmt" "io" "net" "runtime/debug" @@ -52,8 +53,7 @@ type svrTransHandlerFactory struct { // NewSvrTransHandlerFactory ... func NewSvrTransHandlerFactory(provider streamx.ServerProvider) remote.ServerTransHandlerFactory { - sp := streamx.NewServerProvider(provider) // wrapped server provider - return &svrTransHandlerFactory{provider: sp} + return &svrTransHandlerFactory{provider: provider} } func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { @@ -155,9 +155,17 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss stream ink := ri.Invocation().(rpcinfo.InvocationSetter) if si, ok := ss.(istreamx.StreamInfo); ok { - ink.SetServiceName(si.Service()) + sinfo := t.opt.SvcSearcher.SearchService(si.Service(), si.Method(), false) + if sinfo == nil { + return remote.NewTransErrorWithMsg(remote.UnknownService, fmt.Sprintf("unknown service %s", si.Service())) + } + minfo := sinfo.MethodInfo(si.Method()) + if minfo == nil { + return remote.NewTransErrorWithMsg(remote.UnknownMethod, fmt.Sprintf("unknown method %s", si.Method())) + } + ink.SetServiceName(sinfo.ServiceName) ink.SetMethodName(si.Method()) - ink.SetStreamingMode(si.Mode()) + ink.SetStreamingMode(minfo.StreamingMode()) if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil { _ = mutableTo.SetMethod(si.Method()) } diff --git a/pkg/streamx/client_provider_internal.go b/pkg/streamx/client_provider_internal.go deleted file mode 100644 index 7f80aa15be..0000000000 --- a/pkg/streamx/client_provider_internal.go +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package streamx - -import ( - "context" - - "github.com/cloudwego/kitex/pkg/rpcinfo" -) - -// NewClientProvider wrap specific client provider -func NewClientProvider(cs ClientProvider) ClientProvider { - return internalClientProvider{ClientProvider: cs} -} - -type internalClientProvider struct { - ClientProvider -} - -func (p internalClientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (ClientStream, error) { - cs, err := p.ClientProvider.NewStream(ctx, ri) - if err != nil { - return nil, err - } - return cs, nil -} diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go index 73ccac84b4..f4029dc00b 100644 --- a/pkg/streamx/provider/ttstream/client_provier.go +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -25,7 +25,6 @@ import ( "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/ktx" ) @@ -33,19 +32,17 @@ import ( var _ streamx.ClientProvider = (*clientProvider)(nil) // NewClientProvider return a client provider -func NewClientProvider(sinfo *serviceinfo.ServiceInfo, opts ...ClientProviderOption) (streamx.ClientProvider, error) { +func NewClientProvider(opts ...ClientProviderOption) streamx.ClientProvider { cp := new(clientProvider) - cp.sinfo = sinfo cp.transPool = newMuxConnTransPool(DefaultMuxConnConfig) for _, opt := range opts { opt(cp) } - return cp, nil + return cp } type clientProvider struct { transPool transPool - sinfo *serviceinfo.ServiceInfo metaHandler MetaFrameHandler headerHandler HeaderFrameWriteHandler @@ -78,10 +75,10 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (stre if strHeader == nil { strHeader = map[string]string{} } - strHeader[ttheader.HeaderIDLServiceName] = c.sinfo.ServiceName + strHeader[ttheader.HeaderIDLServiceName] = invocation.ServiceName() metainfo.SaveMetaInfoToMap(ctx, strHeader) - trans, err := c.transPool.Get(c.sinfo, addr.Network(), addr.String()) + trans, err := c.transPool.Get(addr.Network(), addr.String()) if err != nil { return nil, err } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool.go b/pkg/streamx/provider/ttstream/client_trans_pool.go index 7ec587d93d..e5c4706463 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool.go @@ -18,13 +18,11 @@ package ttstream import ( "github.com/cloudwego/netpoll" - - "github.com/cloudwego/kitex/pkg/serviceinfo" ) var dialer = netpoll.NewDialer() type transPool interface { - Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (trans *transport, err error) + Get(network, addr string) (trans *transport, err error) Put(trans *transport) } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go index 87c9109c79..5477f30917 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go @@ -19,7 +19,6 @@ package ttstream import ( "time" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" ) @@ -46,7 +45,7 @@ type longConnTransPool struct { config LongConnConfig } -func (c *longConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (trans *transport, err error) { +func (c *longConnTransPool) Get(network, addr string) (trans *transport, err error) { for { o := c.transPool.Pop(addr) if o == nil { @@ -63,7 +62,7 @@ func (c *longConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr st if err != nil { return nil, err } - trans = newTransport(clientTransport, sinfo, conn, c) + trans = newTransport(clientTransport, conn, c) // create new transport return trans, nil } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go index 5d75356ef2..58dcbd6efa 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go @@ -27,7 +27,6 @@ import ( "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/pkg/gofunc" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) @@ -77,7 +76,7 @@ func (tl *muxConnTransList) Close() { tl.L.Unlock() } -func (tl *muxConnTransList) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (*transport, error) { +func (tl *muxConnTransList) Get(network, addr string) (*transport, error) { // fast path idx := atomic.AddUint32(&tl.cursor, 1) % uint32(tl.size) tl.L.RLock() @@ -103,7 +102,7 @@ func (tl *muxConnTransList) Get(sinfo *serviceinfo.ServiceInfo, network, addr st if err != nil { return nil, err } - trans = newTransport(clientTransport, sinfo, conn, tl.pool) + trans = newTransport(clientTransport, conn, tl.pool) _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { // peer close _ = trans.Close(terrors.ErrTransport.WithCause(errors.New("connection closed by peer"))) @@ -132,13 +131,13 @@ type muxConnTransPool struct { cleanerOnce int32 } -func (p *muxConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (trans *transport, err error) { +func (p *muxConnTransPool) Get(network, addr string) (trans *transport, err error) { v, ok := p.pool.Load(addr) if !ok { // multi concurrent Get should get the same TransList object v, _ = p.pool.LoadOrStore(addr, newMuxConnTransList(p.config.PoolSize, p)) } - return v.(*muxConnTransList).Get(sinfo, network, addr) + return v.(*muxConnTransList).Get(network, addr) } func (p *muxConnTransPool) Put(trans *transport) { diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go index d5bd7e463f..5e833885b8 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go @@ -20,7 +20,6 @@ import ( "errors" "time" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) @@ -30,14 +29,14 @@ func newShortConnTransPool() transPool { type shortConnTransPool struct{} -func (p *shortConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network, addr string) (*transport, error) { +func (p *shortConnTransPool) Get(network, addr string) (*transport, error) { // create new connection conn, err := dialer.DialConnection(network, addr, time.Second) if err != nil { return nil, err } // create new transport - trans := newTransport(clientTransport, sinfo, conn, p) + trans := newTransport(clientTransport, conn, p) return trans, nil } diff --git a/pkg/streamx/provider/ttstream/server_provider.go b/pkg/streamx/provider/ttstream/server_provider.go index 0e032c3b06..662495c9b7 100644 --- a/pkg/streamx/provider/ttstream/server_provider.go +++ b/pkg/streamx/provider/ttstream/server_provider.go @@ -29,7 +29,6 @@ import ( "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/ktx" "github.com/cloudwego/kitex/pkg/utils" @@ -43,17 +42,15 @@ type ( var _ streamx.ServerProvider = (*serverProvider)(nil) // NewServerProvider return a server provider -func NewServerProvider(sinfo *serviceinfo.ServiceInfo, opts ...ServerProviderOption) (streamx.ServerProvider, error) { +func NewServerProvider(opts ...ServerProviderOption) streamx.ServerProvider { sp := new(serverProvider) - sp.sinfo = sinfo for _, opt := range opts { opt(sp) } - return sp, nil + return sp } type serverProvider struct { - sinfo *serviceinfo.ServiceInfo metaHandler MetaFrameHandler headerHandler HeaderFrameReadHandler } @@ -74,7 +71,7 @@ func (s serverProvider) Available(ctx context.Context, conn net.Conn) bool { // OnActive will be called when a connection accepted func (s serverProvider) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { nconn := conn.(netpoll.Connection) - trans := newTransport(serverTransport, s.sinfo, nconn, nil) + trans := newTransport(serverTransport, nconn, nil) _ = nconn.(onDisConnectSetter).SetOnDisconnect(func(ctx context.Context, connection netpoll.Connection) { // server only close transport when peer connection closed _ = trans.Close(nil) diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index 7e0cca2c03..30f6e5ab7a 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -42,13 +42,12 @@ var ( _ StreamMeta = (*stream)(nil) ) -func newStream(ctx context.Context, writer streamWriter, mode streamx.StreamingMode, smeta streamFrame) *stream { +func newStream(ctx context.Context, writer streamWriter, smeta streamFrame) *stream { s := new(stream) s.streamFrame = smeta s.StreamMeta = newStreamMeta() s.reader = newStreamReader() s.writer = writer - s.mode = mode s.wheader = make(streamx.Header) s.wtrailer = make(streamx.Trailer) s.headerSig = make(chan int32, 1) @@ -84,7 +83,6 @@ type stream struct { StreamMeta reader *streamReader writer streamWriter - mode streamx.StreamingMode wheader streamx.Header // wheader == nil means it already be sent wtrailer streamx.Trailer // wtrailer == nil means it already be sent @@ -100,10 +98,6 @@ type stream struct { closeCallback []streamxcallopt.StreamCloseCallback } -func (s *stream) Mode() streamx.StreamingMode { - return s.mode -} - func (s *stream) Service() string { if len(s.header) == 0 { return "" diff --git a/pkg/streamx/provider/ttstream/test_utils.go b/pkg/streamx/provider/ttstream/test_utils.go index e78ad68738..823c6b55d7 100644 --- a/pkg/streamx/provider/ttstream/test_utils.go +++ b/pkg/streamx/provider/ttstream/test_utils.go @@ -40,12 +40,12 @@ func newTestStreamPipe(sinfo *serviceinfo.ServiceInfo, method string) (*clientSt intHeader := make(IntHeader) strHeader := make(streamx.Header) - ctrans := newTransport(clientTransport, sinfo, cconn, nil) + ctrans := newTransport(clientTransport, cconn, nil) rawClientStream, err := ctrans.WriteStream(context.Background(), method, intHeader, strHeader) if err != nil { return nil, nil, err } - strans := newTransport(serverTransport, sinfo, sconn, nil) + strans := newTransport(serverTransport, sconn, nil) rawServerStream, err := strans.ReadStream(context.Background()) if err != nil { return nil, nil, err diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index e3dab54cb8..d3b87ad7fb 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -30,7 +30,6 @@ import ( "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" - "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" ) @@ -49,10 +48,9 @@ func isIgnoreError(err error) bool { // transport is used to read/write frames and disturbed frames to different streams type transport struct { - kind int32 - sinfo *serviceinfo.ServiceInfo - conn netpoll.Connection - pool transPool + kind int32 + conn netpoll.Connection + pool transPool // transport should operate directly on stream streams sync.Map // key=streamID val=stream scache []*stream // size is streamCacheSize @@ -62,12 +60,11 @@ type transport struct { closedTrigger chan struct{} } -func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Connection, pool transPool) *transport { +func newTransport(kind int32, conn netpoll.Connection, pool transPool) *transport { // TODO: let it configurable _ = conn.SetReadTimeout(0) t := &transport{ kind: kind, - sinfo: sinfo, conn: conn, pool: pool, streams: sync.Map{}, @@ -94,7 +91,7 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne t.closedTrigger <- struct{}{} }() err = t.loopRead() - }, gofunc.NewBasicInfo(sinfo.ServiceName, addr)) + }, gofunc.NewBasicInfo("", addr)) gofunc.RecoverGoFuncWithInfo(context.Background(), func() { var err error defer func() { @@ -107,7 +104,7 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne t.closedTrigger <- struct{}{} }() err = t.loopWrite() - }, gofunc.NewBasicInfo(sinfo.ServiceName, addr)) + }, gofunc.NewBasicInfo("", addr)) return t } @@ -184,8 +181,7 @@ func (t *transport) readFrame(reader bufiox.Reader) error { var s *stream if fr.typ == headerFrameType && t.kind == serverTransport { // server recv a header frame, we should create a new stream - smode := t.sinfo.MethodInfo(fr.method).StreamingMode() - s = newStream(context.Background(), t, smode, fr.streamFrame) + s = newStream(context.Background(), t, fr.streamFrame) t.storeStream(s) err = t.spipe.Write(context.Background(), s) } else { @@ -298,9 +294,8 @@ func (t *transport) WriteStream( } sid := genStreamID() - smode := t.sinfo.MethodInfo(method).StreamingMode() // new stream first - s := newStream(ctx, t, smode, streamFrame{sid: sid, method: method}) + s := newStream(ctx, t, streamFrame{sid: sid, method: method}) t.storeStream(s) // send create stream request for server fr := newFrame(streamFrame{sid: sid, method: method, header: strHeader, meta: intHeader}, headerFrameType, nil) diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go index a994d0a99c..a15ef6e5f2 100644 --- a/pkg/streamx/provider/ttstream/transport_test.go +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -71,10 +71,10 @@ func TestTransportBasic(t *testing.T) { intHeader[0] = "test" strHeader := make(streamx.Header) strHeader["key"] = "val" - ctrans := newTransport(clientTransport, testServiceInfo, cconn, nil) + ctrans := newTransport(clientTransport, cconn, nil) rawClientStream, err := ctrans.WriteStream(context.Background(), "Bidi", intHeader, strHeader) test.Assert(t, err == nil, err) - strans := newTransport(serverTransport, testServiceInfo, sconn, nil) + strans := newTransport(serverTransport, sconn, nil) rawServerStream, err := strans.ReadStream(context.Background()) test.Assert(t, err == nil, err) @@ -139,10 +139,10 @@ func TestTransportServerStreaming(t *testing.T) { intHeader := make(IntHeader) strHeader := make(streamx.Header) - ctrans := newTransport(clientTransport, testServiceInfo, cconn, nil) + ctrans := newTransport(clientTransport, cconn, nil) rawClientStream, err := ctrans.WriteStream(context.Background(), "Bidi", intHeader, strHeader) test.Assert(t, err == nil, err) - strans := newTransport(serverTransport, testServiceInfo, sconn, nil) + strans := newTransport(serverTransport, sconn, nil) rawServerStream, err := strans.ReadStream(context.Background()) test.Assert(t, err == nil, err) @@ -203,10 +203,10 @@ func TestTransportException(t *testing.T) { test.Assert(t, err == nil, err) // server send data - ctrans := newTransport(clientTransport, testServiceInfo, cconn, nil) + ctrans := newTransport(clientTransport, cconn, nil) rawClientStream, err := ctrans.WriteStream(context.Background(), "Bidi", make(IntHeader), make(streamx.Header)) test.Assert(t, err == nil, err) - strans := newTransport(serverTransport, testServiceInfo, sconn, nil) + strans := newTransport(serverTransport, sconn, nil) rawServerStream, err := strans.ReadStream(context.Background()) test.Assert(t, err == nil, err) cStream := newClientStream(rawClientStream) diff --git a/pkg/streamx/server_provider_internal.go b/pkg/streamx/server_provider_internal.go deleted file mode 100644 index f8d4be0c9e..0000000000 --- a/pkg/streamx/server_provider_internal.go +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package streamx - -import ( - "context" - "net" -) - -// NewServerProvider wrap specific server provider -func NewServerProvider(ss ServerProvider) ServerProvider { - if _, ok := ss.(*internalServerProvider); ok { - return ss - } - return internalServerProvider{ServerProvider: ss} -} - -type internalServerProvider struct { - ServerProvider -} - -func (p internalServerProvider) OnStream(ctx context.Context, conn net.Conn) (context.Context, ServerStream, error) { - ctx, ss, err := p.ServerProvider.OnStream(ctx, conn) - if err != nil { - return nil, nil, err - } - return ctx, ss, nil -} diff --git a/pkg/streamx/streamx_user_test.go b/pkg/streamx/streamx_user_test.go index a410f794c8..50c48b8540 100644 --- a/pkg/streamx/streamx_user_test.go +++ b/pkg/streamx/streamx_user_test.go @@ -55,12 +55,12 @@ type testCase struct { } func init() { - sp, _ := ttstream.NewServerProvider(testServiceInfo) - cp, _ := ttstream.NewClientProvider(testServiceInfo, ttstream.WithClientLongConnPool(ttstream.LongConnConfig{MaxIdleTimeout: time.Millisecond * 100})) + sp := ttstream.NewServerProvider() + cp := ttstream.NewClientProvider(ttstream.WithClientLongConnPool(ttstream.LongConnConfig{MaxIdleTimeout: time.Millisecond * 100})) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_LongConn", ClientProvider: cp, ServerProvider: sp}) - cp, _ = ttstream.NewClientProvider(testServiceInfo, ttstream.WithClientShortConnPool()) + cp = ttstream.NewClientProvider(ttstream.WithClientShortConnPool()) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_ShortConn", ClientProvider: cp, ServerProvider: sp}) - cp, _ = ttstream.NewClientProvider(testServiceInfo, ttstream.WithClientMuxConnPool(ttstream.MuxConnConfig{PoolSize: 8, MaxIdleTimeout: time.Millisecond * 100})) + cp = ttstream.NewClientProvider(ttstream.WithClientMuxConnPool(ttstream.MuxConnConfig{PoolSize: 8, MaxIdleTimeout: time.Millisecond * 100})) providerTestCases = append(providerTestCases, testCase{Name: "TTHeader_Mux", ClientProvider: cp, ServerProvider: sp}) } diff --git a/server/streamxserver/option.go b/server/streamxserver/option.go index b097c91a59..794180a438 100644 --- a/server/streamxserver/option.go +++ b/server/streamxserver/option.go @@ -27,10 +27,11 @@ import ( "github.com/cloudwego/kitex/server" ) +// Deprecated: Note that it maybe refactor in the next version func WithProvider(provider streamx.ServerProvider) server.Option { return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { - // streamx provider server trans handler can use with other protocol o.StreamX.Provider = provider + // streamx provider server trans handler can use with other protocol o.RemoteOpt.SvrHandlerFactory = detection.NewSvrTransHandlerFactory( netpoll.NewSvrTransHandlerFactory(), streamxstrans.NewSvrTransHandlerFactory(provider), @@ -39,12 +40,14 @@ func WithProvider(provider streamx.ServerProvider) server.Option { }} } +// WithStreamRecvMiddleware add recv middleware func WithStreamRecvMiddleware(mw streamx.StreamRecvMiddleware) server.Option { return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { o.StreamX.StreamRecvMiddlewares = append(o.StreamX.StreamRecvMiddlewares, mw) }} } +// WithStreamSendMiddleware add send middleware func WithStreamSendMiddleware(mw streamx.StreamSendMiddleware) server.Option { return server.Option{F: func(o *internal_server.Options, di *utils.Slice) { o.StreamX.StreamSendMiddlewares = append(o.StreamX.StreamSendMiddlewares, mw) From 262ad8591969f73d08bcecbfd3b6e40fcf1d1b0c Mon Sep 17 00:00:00 2001 From: Zhuowei Wang Date: Thu, 28 Nov 2024 14:48:46 +0800 Subject: [PATCH 33/34] chore: using gopool --- pkg/remote/trans/streamx/server_handler.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pkg/remote/trans/streamx/server_handler.go b/pkg/remote/trans/streamx/server_handler.go index ef35978d44..3c5a683cdd 100644 --- a/pkg/remote/trans/streamx/server_handler.go +++ b/pkg/remote/trans/streamx/server_handler.go @@ -24,11 +24,10 @@ import ( "net" "runtime/debug" "sync" - "time" istreamx "github.com/cloudwego/kitex/internal/streamx" - "github.com/cloudwego/kitex/internal/wpool" "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -45,8 +44,6 @@ import ( Other interface is used by trans pipeline */ -var streamWorkerPool = wpool.New(128, time.Second) - type svrTransHandlerFactory struct { provider streamx.ServerProvider } @@ -120,7 +117,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) } wg.Add(1) // stream level goroutine - streamWorkerPool.GoCtx(nctx, func() { + gofunc.GoFunc(nctx, func() { defer wg.Done() err := t.OnStream(nctx, conn, ss) if err != nil && !errors.Is(err, io.EOF) { From b11fc2c20964eb2f574e40ce1c2e565c0171ffa7 Mon Sep 17 00:00:00 2001 From: Scout Wang Date: Thu, 28 Nov 2024 15:07:55 +0800 Subject: [PATCH 34/34] chore: unexport detailed err type (#1629) --- .../ttstream/client_trans_pool_muxconn.go | 3 +- .../ttstream/client_trans_pool_shortconn.go | 4 +- pkg/streamx/provider/ttstream/exception.go | 53 ++++++++++++++++ .../terrors_test.go => exception_test.go} | 26 ++++---- pkg/streamx/provider/ttstream/frame.go | 13 ++-- pkg/streamx/provider/ttstream/stream.go | 11 ++-- .../provider/ttstream/terrors/terrors.go | 60 ------------------- .../provider/ttstream/transport_test.go | 3 +- pkg/streamx/streamx_common_test.go | 4 +- 9 files changed, 84 insertions(+), 93 deletions(-) rename pkg/streamx/provider/ttstream/{terrors/terrors_test.go => exception_test.go} (60%) delete mode 100644 pkg/streamx/provider/ttstream/terrors/terrors.go diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go index 58dcbd6efa..fc0d50b3d3 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_muxconn.go @@ -27,7 +27,6 @@ import ( "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/pkg/gofunc" - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) var DefaultMuxConnConfig = MuxConnConfig{ @@ -105,7 +104,7 @@ func (tl *muxConnTransList) Get(network, addr string) (*transport, error) { trans = newTransport(clientTransport, conn, tl.pool) _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { // peer close - _ = trans.Close(terrors.ErrTransport.WithCause(errors.New("connection closed by peer"))) + _ = trans.Close(errTransport.WithCause(errors.New("connection closed by peer"))) return nil }) runtime.SetFinalizer(trans, func(trans *transport) { diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go index 5e833885b8..74d839313d 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_shortconn.go @@ -19,8 +19,6 @@ package ttstream import ( "errors" "time" - - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) func newShortConnTransPool() transPool { @@ -41,5 +39,5 @@ func (p *shortConnTransPool) Get(network, addr string) (*transport, error) { } func (p *shortConnTransPool) Put(trans *transport) { - _ = trans.Close(terrors.ErrTransport.WithCause(errors.New("short connection closed"))) + _ = trans.Close(errTransport.WithCause(errors.New("short connection closed"))) } diff --git a/pkg/streamx/provider/ttstream/exception.go b/pkg/streamx/provider/ttstream/exception.go index 17f51b157d..59adfb274d 100644 --- a/pkg/streamx/provider/ttstream/exception.go +++ b/pkg/streamx/provider/ttstream/exception.go @@ -16,7 +16,60 @@ package ttstream +import ( + "errors" + + "github.com/cloudwego/kitex/pkg/kerrors" +) + type tException interface { Error() string TypeId() int32 } + +var ( + errApplicationException = newExceptionType("application exception", nil, 12001) + errUnexpectedHeader = newExceptionType("unexpected header frame", kerrors.ErrStreamingProtocol, 12002) + errIllegalBizErr = newExceptionType("illegal bizErr", kerrors.ErrStreamingProtocol, 12003) + errIllegalFrame = newExceptionType("illegal frame", kerrors.ErrStreamingProtocol, 12004) + errIllegalOperation = newExceptionType("illegal operation", kerrors.ErrStreamingProtocol, 12005) + errTransport = newExceptionType("transport is closing", kerrors.ErrStreamingProtocol, 12006) +) + +type exceptionType struct { + message string + // parent exceptionType + basic error + // detailed err + cause error + typeId int32 +} + +func newExceptionType(message string, parent error, typeId int32) *exceptionType { + return &exceptionType{message: message, basic: parent, typeId: typeId} +} + +func (e *exceptionType) WithCause(err error) error { + return &exceptionType{basic: e, cause: err} +} + +func (e *exceptionType) Error() string { + if e.cause == nil { + return e.message + } + return "[" + e.basic.Error() + "] " + e.cause.Error() +} + +func (e *exceptionType) Is(target error) bool { + return target == e || errors.Is(e.basic, target) || errors.Is(e.cause, target) +} + +// TypeId is used for aligning with ApplicationException interface +func (e *exceptionType) TypeId() int32 { + return e.typeId +} + +// Code is used for uniform code retrieving by Kitex in the future +func (e *exceptionType) Code() int32 { + return e.typeId +} diff --git a/pkg/streamx/provider/ttstream/terrors/terrors_test.go b/pkg/streamx/provider/ttstream/exception_test.go similarity index 60% rename from pkg/streamx/provider/ttstream/terrors/terrors_test.go rename to pkg/streamx/provider/ttstream/exception_test.go index 191e038a32..14b18463d9 100644 --- a/pkg/streamx/provider/ttstream/terrors/terrors_test.go +++ b/pkg/streamx/provider/ttstream/exception_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package terrors +package ttstream import ( "errors" @@ -28,23 +28,29 @@ import ( func TestErrors(t *testing.T) { causeErr := fmt.Errorf("test1") - newErr := ErrIllegalFrame.WithCause(causeErr) - test.Assert(t, errors.Is(newErr, ErrIllegalFrame), newErr) + newErr := errIllegalFrame.WithCause(causeErr) + test.Assert(t, errors.Is(newErr, errIllegalFrame), newErr) test.Assert(t, errors.Is(newErr, kerrors.ErrStreamingProtocol), newErr) - test.Assert(t, strings.Contains(newErr.Error(), ErrIllegalFrame.Error())) + test.Assert(t, strings.Contains(newErr.Error(), errIllegalFrame.Error())) test.Assert(t, strings.Contains(newErr.Error(), causeErr.Error())) + + appErr := errApplicationException.WithCause(causeErr) + test.Assert(t, errors.Is(appErr, errApplicationException), appErr) + test.Assert(t, !errors.Is(appErr, kerrors.ErrStreamingProtocol), appErr) + test.Assert(t, strings.Contains(appErr.Error(), errApplicationException.Error())) + test.Assert(t, strings.Contains(appErr.Error(), causeErr.Error())) } func TestCommonParentKerror(t *testing.T) { errs := []error{ - ErrUnexpectedHeader, - ErrApplicationException, - ErrIllegalBizErr, - ErrIllegalFrame, - ErrIllegalOperation, - ErrTransport, + errUnexpectedHeader, + errIllegalBizErr, + errIllegalFrame, + errIllegalOperation, + errTransport, } for _, err := range errs { test.Assert(t, errors.Is(err, kerrors.ErrStreamingProtocol), err) } + test.Assert(t, !errors.Is(errApplicationException, kerrors.ErrStreamingProtocol)) } diff --git a/pkg/streamx/provider/ttstream/frame.go b/pkg/streamx/provider/ttstream/frame.go index a0f5fdcfdc..383dd4b6eb 100644 --- a/pkg/streamx/provider/ttstream/frame.go +++ b/pkg/streamx/provider/ttstream/frame.go @@ -32,7 +32,6 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) const ( @@ -107,7 +106,7 @@ func EncodeFrame(ctx context.Context, writer bufiox.Writer, fr *Frame) (err erro totalLenField, err := ttheader.Encode(ctx, param, writer) if err != nil { - return terrors.ErrIllegalFrame.WithCause(err) + return errIllegalFrame.WithCause(err) } if len(fr.payload) > 0 { if nw, ok := writer.(gopkgthrift.NocopyWriter); ok { @@ -116,7 +115,7 @@ func EncodeFrame(ctx context.Context, writer bufiox.Writer, fr *Frame) (err erro _, err = writer.WriteBinary(fr.payload) } if err != nil { - return terrors.ErrTransport.WithCause(err) + return errTransport.WithCause(err) } } written = writer.WrittenLen() - written @@ -131,10 +130,10 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro if errors.Is(err, io.EOF) { return nil, err } - return nil, terrors.ErrIllegalFrame.WithCause(err) + return nil, errIllegalFrame.WithCause(err) } if dp.Flags&ttheader.HeaderFlagsStreaming == 0 { - return nil, terrors.ErrIllegalFrame.WithCause(fmt.Errorf("unexpected header flags: %d", dp.Flags)) + return nil, errIllegalFrame.WithCause(fmt.Errorf("unexpected header flags: %d", dp.Flags)) } var ftype int32 @@ -153,7 +152,7 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro ftype = trailerFrameType ftrailer = dp.StrInfo default: - return nil, terrors.ErrIllegalFrame.WithCause(fmt.Errorf("unexpected frame type: %v", dp.IntInfo[ttheader.FrameType])) + return nil, errIllegalFrame.WithCause(fmt.Errorf("unexpected frame type: %v", dp.IntInfo[ttheader.FrameType])) } fmethod := dp.IntInfo[ttheader.ToMethod] fsid := dp.SeqID @@ -165,7 +164,7 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro _, err = reader.ReadBinary(fpayload) // copy read _ = reader.Release(err) if err != nil { - return nil, terrors.ErrTransport.WithCause(err) + return nil, errTransport.WithCause(err) } } else { _ = reader.Release(nil) diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index 30f6e5ab7a..42760a1bf8 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -32,7 +32,6 @@ import ( "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" "github.com/cloudwego/kitex/pkg/transmeta" ) @@ -111,7 +110,7 @@ func (s *stream) Method() string { func (s *stream) SendMsg(ctx context.Context, msg any) (err error) { if atomic.LoadInt32(&s.selfEOF) != 0 { - return terrors.ErrIllegalOperation.WithCause(errors.New("stream is closed send")) + return errIllegalOperation.WithCause(errors.New("stream is closed send")) } // encode payload payload, err := EncodePayload(ctx, msg) @@ -303,13 +302,13 @@ func (s *stream) onReadMetaFrame(fr *Frame) (err error) { func (s *stream) onReadHeaderFrame(fr *Frame) (err error) { if s.header != nil { - return terrors.ErrUnexpectedHeader.WithCause(fmt.Errorf("stream[%d] already set header", s.sid)) + return errUnexpectedHeader.WithCause(fmt.Errorf("stream[%d] already set header", s.sid)) } s.header = fr.header select { case s.headerSig <- streamSigActive: default: - return terrors.ErrUnexpectedHeader.WithCause(fmt.Errorf("stream[%d] already set header", s.sid)) + return errUnexpectedHeader.WithCause(fmt.Errorf("stream[%d] already set header", s.sid)) } klog.Debugf("stream[%s] read header: %v", s.method, fr.header) return nil @@ -328,12 +327,12 @@ func (s *stream) onReadTrailerFrame(fr *Frame) (err error) { if len(fr.payload) > 0 { // exception is type of (*thrift.ApplicationException) _, _, err = thrift.UnmarshalFastMsg(fr.payload, nil) - exception = terrors.ErrApplicationException.WithCause(err) + exception = errApplicationException.WithCause(err) } else if len(fr.trailer) > 0 { // when server-side returns biz error, payload is empty and biz error information is stored in trailer frame header bizErr, err := transmeta.ParseBizStatusErr(fr.trailer) if err != nil { - exception = terrors.ErrIllegalBizErr.WithCause(err) + exception = errIllegalBizErr.WithCause(err) } else if bizErr != nil { // bizErr is independent of rpc exception handling exception = bizErr diff --git a/pkg/streamx/provider/ttstream/terrors/terrors.go b/pkg/streamx/provider/ttstream/terrors/terrors.go deleted file mode 100644 index 769ffc2287..0000000000 --- a/pkg/streamx/provider/ttstream/terrors/terrors.go +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package terrors - -import ( - "errors" - - "github.com/cloudwego/kitex/pkg/kerrors" -) - -// terrors define TTHeader Streaming-related protocol errors, they all inherit from ErrStreamingProtocol in kerrors. -var ( - ErrUnexpectedHeader = newErrType("unexpected header frame") - ErrApplicationException = newErrType("application exception") - ErrIllegalBizErr = newErrType("illegal bizErr") - ErrIllegalFrame = newErrType("illegal frame") - ErrIllegalOperation = newErrType("illegal operation") - ErrTransport = newErrType("transport is closing") -) - -type errType struct { - message string - // parent errType - basic error - // detailed err - cause error -} - -func newErrType(message string) *errType { - return &errType{message: message, basic: kerrors.ErrStreamingProtocol} -} - -func (e *errType) WithCause(err error) error { - return &errType{basic: e, cause: err} -} - -func (e *errType) Error() string { - if e.cause == nil { - return e.message - } - return "[" + e.basic.Error() + "] " + e.cause.Error() -} - -func (e *errType) Is(target error) bool { - return target == e || errors.Is(e.basic, target) || errors.Is(e.cause, target) -} diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go index a15ef6e5f2..556cd6dd8b 100644 --- a/pkg/streamx/provider/ttstream/transport_test.go +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -36,7 +36,6 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streamx" - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" "github.com/cloudwego/kitex/server/streamxserver" ) @@ -246,7 +245,7 @@ func TestTransportException(t *testing.T) { cStream = newClientStream(rawClientStream) err = cStream.RecvMsg(context.Background(), res) test.Assert(t, err != nil, err) - test.Assert(t, errors.Is(err, terrors.ErrIllegalFrame), err) + test.Assert(t, errors.Is(err, errIllegalFrame), err) } func TestStreamID(t *testing.T) { diff --git a/pkg/streamx/streamx_common_test.go b/pkg/streamx/streamx_common_test.go index 30091a9cc9..4f9a5da1e4 100644 --- a/pkg/streamx/streamx_common_test.go +++ b/pkg/streamx/streamx_common_test.go @@ -25,7 +25,6 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" - "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/terrors" ) const ( @@ -80,8 +79,7 @@ func validateMetadata(ctx context.Context) bool { } func assertNormalErr(t *testing.T, err error) { - test.Assert(t, errors.Is(err, terrors.ErrApplicationException), err) - test.Assert(t, errors.Is(err, kerrors.ErrStreamingProtocol), err) + test.Assert(t, !errors.Is(err, kerrors.ErrStreamingProtocol), err) } func assertBizErr(t *testing.T, err error) {