Skip to content

Commit

Permalink
feat: server middleware pass through
Browse files Browse the repository at this point in the history
  • Loading branch information
felix021 committed Feb 28, 2024
1 parent 4e44114 commit e8f04b8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
2 changes: 2 additions & 0 deletions internal/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ type Options struct {

RefuseTrafficWithoutServiceName bool
EnableContextTimeout bool

EnableStreamingContextPassThrough bool
}

type Limit struct {
Expand Down
9 changes: 9 additions & 0 deletions server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,15 @@ func WithRefuseTrafficWithoutServiceName() Option {
// so there's no need to use this option.
func WithEnableContextTimeout(enable bool) Option {
return Option{F: func(o *internal_server.Options, di *utils.Slice) {
di.Push(fmt.Sprintf("WithEnableContextTimeout({%+v})", enable))
o.EnableContextTimeout = enable
}}
}

// WithEnableStreamingContextPassThrough enables passing through context modified by server middleware to handler.
func WithEnableStreamingContextPassThrough() Option {
return Option{F: func(o *internal_server.Options, di *utils.Slice) {
di.Push(fmt.Sprintf("EnableStreamingContextPassThrough()"))

Check failure on line 397 in server/option.go

View workflow job for this annotation

GitHub Actions / lint

S1039: unnecessary use of fmt.Sprintf (gosimple)
o.EnableStreamingContextPassThrough = true
}}

Check warning on line 399 in server/option.go

View check run for this annotation

Codecov / codecov/patch

server/option.go#L395-L399

Added lines #L395 - L399 were not covered by tests
}
23 changes: 22 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/stats"
"github.com/cloudwego/kitex/pkg/streaming"
)

// Server is an abstraction of an RPC server. It accepts connections and dispatches them to the service
Expand Down Expand Up @@ -337,7 +338,16 @@ func (s *server) invokeHandleEndpoint() endpoint.Endpoint {
// clear session
backup.ClearCtx()
}()
implHandlerFunc := svcInfo.MethodInfo(methodName).Handler()
methodInfo := svcInfo.MethodInfo(methodName)
if methodInfo.IsStreaming() && s.opt.EnableStreamingContextPassThrough {
if sArg, ok := args.(*streaming.Args); ok {
sArg.Stream = &contextStream{
Stream: sArg.Stream,
ctx: ctx,

Check warning on line 346 in server/server.go

View check run for this annotation

Codecov / codecov/patch

server/server.go#L343-L346

Added lines #L343 - L346 were not covered by tests
}
}
}
implHandlerFunc := methodInfo.Handler()
rpcinfo.Record(ctx, ri, stats.ServerHandleStart, nil)
// set session
backup.BackupCtx(ctx)
Expand All @@ -355,6 +365,17 @@ func (s *server) invokeHandleEndpoint() endpoint.Endpoint {
}
}

// contextStream is used to return the ctx node passed in by server middlewares which might contain new key/values
type contextStream struct {
streaming.Stream
ctx context.Context
}

// Context implements streaming.Stream
func (s *contextStream) Context() context.Context {
return s.ctx

Check warning on line 376 in server/server.go

View check run for this annotation

Codecov / codecov/patch

server/server.go#L375-L376

Added lines #L375 - L376 were not covered by tests
}

func (s *server) initBasicRemoteOption() {
remoteOpt := s.opt.RemoteOpt
remoteOpt.TargetSvcInfo = s.targetSvcInfo
Expand Down

0 comments on commit e8f04b8

Please sign in to comment.