From e8f04b84a51782e838ce732ed29962057de32290 Mon Sep 17 00:00:00 2001 From: "felix.fengmin" Date: Wed, 28 Feb 2024 20:56:20 +0800 Subject: [PATCH] feat: server middleware pass through --- internal/server/option.go | 2 ++ server/option.go | 9 +++++++++ server/server.go | 23 ++++++++++++++++++++++- 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/internal/server/option.go b/internal/server/option.go index b540f03ce4..ece1d9408e 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -100,6 +100,8 @@ type Options struct { RefuseTrafficWithoutServiceName bool EnableContextTimeout bool + + EnableStreamingContextPassThrough bool } type Limit struct { diff --git a/server/option.go b/server/option.go index 59e587d7cd..9dd4c4cc68 100644 --- a/server/option.go +++ b/server/option.go @@ -386,6 +386,15 @@ func WithRefuseTrafficWithoutServiceName() Option { // so there's no need to use this option. func WithEnableContextTimeout(enable bool) Option { return Option{F: func(o *internal_server.Options, di *utils.Slice) { + di.Push(fmt.Sprintf("WithEnableContextTimeout({%+v})", enable)) o.EnableContextTimeout = enable }} } + +// WithEnableStreamingContextPassThrough enables passing through context modified by server middleware to handler. +func WithEnableStreamingContextPassThrough() Option { + return Option{F: func(o *internal_server.Options, di *utils.Slice) { + di.Push(fmt.Sprintf("EnableStreamingContextPassThrough()")) + o.EnableStreamingContextPassThrough = true + }} +} diff --git a/server/server.go b/server/server.go index f2d6e135e2..b1cca3381b 100644 --- a/server/server.go +++ b/server/server.go @@ -45,6 +45,7 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" + "github.com/cloudwego/kitex/pkg/streaming" ) // Server is an abstraction of an RPC server. It accepts connections and dispatches them to the service @@ -337,7 +338,16 @@ func (s *server) invokeHandleEndpoint() endpoint.Endpoint { // clear session backup.ClearCtx() }() - implHandlerFunc := svcInfo.MethodInfo(methodName).Handler() + methodInfo := svcInfo.MethodInfo(methodName) + if methodInfo.IsStreaming() && s.opt.EnableStreamingContextPassThrough { + if sArg, ok := args.(*streaming.Args); ok { + sArg.Stream = &contextStream{ + Stream: sArg.Stream, + ctx: ctx, + } + } + } + implHandlerFunc := methodInfo.Handler() rpcinfo.Record(ctx, ri, stats.ServerHandleStart, nil) // set session backup.BackupCtx(ctx) @@ -355,6 +365,17 @@ func (s *server) invokeHandleEndpoint() endpoint.Endpoint { } } +// contextStream is used to return the ctx node passed in by server middlewares which might contain new key/values +type contextStream struct { + streaming.Stream + ctx context.Context +} + +// Context implements streaming.Stream +func (s *contextStream) Context() context.Context { + return s.ctx +} + func (s *server) initBasicRemoteOption() { remoteOpt := s.opt.RemoteOpt remoteOpt.TargetSvcInfo = s.targetSvcInfo