From b2492959ca424e3ba6e28e659dab6c0b453059ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Jochum?= Date: Sun, 29 Sep 2024 19:01:26 +0200 Subject: [PATCH] feat(client/orb,server/*): Implement metadata transfer and server middleware. Fixes #15. --- .golangci.yaml | 3 + client/orb/client.go | 54 ++- client/orb/config.go | 2 +- client/orb/transport/basehertz/basehertz.go | 28 +- client/orb/transport/basehttp/basehttp.go | 43 +- client/orb/transport/drpc/drpc.go | 26 +- client/orb/transport/grpc/go.mod | 2 +- client/orb/transport/grpc/grpc.go | 24 +- client/tests/cmd/tests_server/wire_gen.go | 6 +- client/tests/go.mod | 3 +- client/tests/go.sum | 6 +- client/tests/handler/handler.go | 23 +- client/tests/proto/echo.pb.go | 50 +-- client/tests/proto/echo.proto | 10 +- client/tests/proto/echo_grpc.pb.go | 105 ----- .../{echo_drpc.pb.go => echo_orb-drpc.pb.go} | 77 ++-- client/tests/proto/echo_orb-grpc.pb.go | 121 ++++++ client/tests/proto/echo_orb.pb.go | 74 ++-- client/tests/proto/gen.go | 6 +- .../tests/proto/google/api/annotations.proto | 31 -- client/tests/proto/google/api/http.proto | 379 ------------------ client/tests/tests.go | 45 ++- .../drpc/drpc_test.go | 0 .../grpc/grpc_test.go | 0 .../h2c/h2c_test.go | 0 .../hertzh2c/hertzh2c_test.go | 0 .../hertzhttp/hertzhttp_test.go | 0 .../http/http_test.go | 0 .../http3/http3_test.go | 0 .../https/https_test.go | 0 event/natsjs/natsjs.go | 6 +- server/cmd/protoc-gen-go-orb/orb/template.go | 22 +- server/cmd/protoc-gen-go-orb/orbdrpc/drpc.go | 5 - server/drpc/config.go | 25 +- server/drpc/drpc.go | 65 ++- server/drpc/go.mod | 2 + server/drpc/go.sum | 6 + server/drpc/handle_rpc.go | 112 ++++++ server/drpc/message/message.pb.go | 165 ++++++++ server/drpc/message/message.proto | 12 + server/drpc/mux.go | 98 +++++ server/drpc/plugin.go | 2 +- server/grpc/error.go | 40 ++ server/grpc/grpc.go | 4 +- server/grpc/interceptor.go | 65 ++- server/grpc/plugin.go | 2 +- server/grpc/tests/grpc_test.go | 12 +- server/grpc/tests/util/grpc/grpc.go | 4 +- server/hertz/config.go | 25 +- server/hertz/handler.go | 41 +- server/hertz/hertz.go | 25 +- server/hertz/plugin.go | 2 +- server/http/config.go | 23 +- server/http/entrypoint.go | 31 +- server/http/handler.go | 46 ++- server/http/plugin.go | 2 +- server/http/tests/http_test.go | 20 +- server/http/tests/proto/echo_http.micro.pb.go | 4 +- server/tests/server_test.go | 14 +- 59 files changed, 1156 insertions(+), 842 deletions(-) delete mode 100644 client/tests/proto/echo_grpc.pb.go rename client/tests/proto/{echo_drpc.pb.go => echo_orb-drpc.pb.go} (62%) create mode 100644 client/tests/proto/echo_orb-grpc.pb.go delete mode 100644 client/tests/proto/google/api/annotations.proto delete mode 100644 client/tests/proto/google/api/http.proto rename client/tests/{orb_transport => transport}/drpc/drpc_test.go (100%) rename client/tests/{orb_transport => transport}/grpc/grpc_test.go (100%) rename client/tests/{orb_transport => transport}/h2c/h2c_test.go (100%) rename client/tests/{orb_transport => transport}/hertzh2c/hertzh2c_test.go (100%) rename client/tests/{orb_transport => transport}/hertzhttp/hertzhttp_test.go (100%) rename client/tests/{orb_transport => transport}/http/http_test.go (100%) rename client/tests/{orb_transport => transport}/http3/http3_test.go (100%) rename client/tests/{orb_transport => transport}/https/https_test.go (100%) create mode 100644 server/drpc/handle_rpc.go create mode 100644 server/drpc/message/message.pb.go create mode 100644 server/drpc/message/message.proto create mode 100644 server/drpc/mux.go create mode 100644 server/grpc/error.go diff --git a/.golangci.yaml b/.golangci.yaml index eaf671f7..ff2b2602 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -213,6 +213,9 @@ linters: # Deprecated - exportloopref + # Doesn't work for us + - contextcheck + issues: # List of regexps of issue texts to exclude, empty list by default. # But independently from this option we use default exclude patterns, diff --git a/client/orb/client.go b/client/orb/client.go index 7b7e2eb6..86a2685f 100644 --- a/client/orb/client.go +++ b/client/orb/client.go @@ -211,25 +211,24 @@ func (c *Client) Call( ) (resp *client.RawResponse, err error) { co := c.makeOptions(opts...) - // Wrap middlewares - call := c.call - for _, m := range c.middlewares { - call = m.Call(call) - } - - return call(ctx, req, co) -} + // Add metadata to the context. + ctx = metadata.EnsureIncoming(ctx) + ctx = metadata.EnsureOutgoing(ctx) -func (c *Client) call(ctx context.Context, req *client.Request[any, any], opts *client.CallOptions) (resp *client.RawResponse, err error) { - transport, err := c.transportForReq(ctx, req, opts) + transport, err := c.transportForReq(ctx, req, co) if err != nil { return nil, err } - // Add metadata to the context. - ctx = metadata.Ensure(ctx) + // Wrap middlewares + call := func(ctx context.Context, req *client.Request[any, any], opts *client.CallOptions) (*client.RawResponse, error) { + return transport.Call(ctx, req, opts) + } + for _, m := range c.middlewares { + call = m.Call(call) + } - return transport.Call(ctx, req, opts) + return call(ctx, req, co) } // CallNoCodec does the actual call without codecs. @@ -241,25 +240,24 @@ func (c *Client) CallNoCodec( ) error { co := c.makeOptions(opts...) - // Wrap middlewares - call := c.callNoCodec - for _, m := range c.middlewares { - call = m.CallNoCodec(call) - } - // Add metadata to the context. - ctx = metadata.Ensure(ctx) - - return call(ctx, req, result, co) -} + ctx = metadata.EnsureIncoming(ctx) + ctx = metadata.EnsureOutgoing(ctx) -func (c *Client) callNoCodec(ctx context.Context, req *client.Request[any, any], result any, opts *client.CallOptions) (err error) { - transport, err := c.transportForReq(ctx, req, opts) + transport, err := c.transportForReq(ctx, req, co) if err != nil { return err } - return transport.CallNoCodec(ctx, req, result, opts) + // Wrap middlewares + call := func(ctx context.Context, req *client.Request[any, any], result any, opts *client.CallOptions) error { + return transport.CallNoCodec(ctx, req, result, opts) + } + for _, m := range c.middlewares { + call = m.CallNoCodec(call) + } + + return call(ctx, req, result, co) } // New creates a new orb client. This functions should rarely be called manually. @@ -293,8 +291,8 @@ func New(cfg Config, log log.Logger, registry registry.Type) *Client { } } -// ProvideClient is the wire provider for client. -func ProvideClient( +// Provide is the wire provider for client. +func Provide( name types.ServiceName, data types.ConfigData, logger log.Logger, diff --git a/client/orb/config.go b/client/orb/config.go index d1ddf5bf..58aafa24 100644 --- a/client/orb/config.go +++ b/client/orb/config.go @@ -12,7 +12,7 @@ import ( const Name = "orb" func init() { - client.Register(Name, ProvideClient) + client.Register(Name, Provide) } // Config is the config for the orb client. diff --git a/client/orb/transport/basehertz/basehertz.go b/client/orb/transport/basehertz/basehertz.go index f5550dcb..81cd5f1e 100644 --- a/client/orb/transport/basehertz/basehertz.go +++ b/client/orb/transport/basehertz/basehertz.go @@ -5,7 +5,7 @@ import ( "bytes" "context" "fmt" - "strings" + "slices" hclient "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/protocol" @@ -19,8 +19,8 @@ import ( "github.com/go-orb/plugins/client/orb" ) -// orbHeader is the prefix for every orb HTTP header. -const orbHeader = "__orb-" +//nolint:gochecknoglobals +var stdHeaders = []string{"Content-Length", "Content-Type", "Date", "Server"} var _ (orb.Transport) = (*Transport)(nil) @@ -87,10 +87,10 @@ func (t *Transport) Call(ctx context.Context, req *client.Request[any, any], opt hReq.SetRequestURI(fmt.Sprintf("%s://%s/%s", t.scheme, node.Address, req.Endpoint())) // Set metadata key=value to request headers. - md, ok := metadata.From(ctx) + md, ok := metadata.OutgoingFrom(ctx) if ok { for name, value := range md { - hReq.Header.Set(orbHeader+name, value) + hReq.Header.Set(name, value) } } @@ -139,21 +139,23 @@ func (t *Transport) call2( res := &client.RawResponse{ ContentType: hRes.Header.Get("Content-Type"), Body: &hresBodyCloserWrapper{buff: buff}, - Metadata: make(metadata.Metadata), } if hRes.StatusCode() != consts.StatusOK { return res, orberrors.NewHTTP(hRes.StatusCode()) } - // Copy headers to the RawResponse. - for _, v := range hRes.Header.GetHeaders() { - k := string(v.GetKey()) - if !strings.HasPrefix(strings.ToLower(k), orbHeader) { - continue - } + if opts.Headers != nil { + for _, v := range hRes.Header.GetHeaders() { + k := string(v.GetKey()) + + // Skip std headers. + if slices.Contains(stdHeaders, k) { + continue + } - res.Metadata[k[len(orbHeader):]] = string(v.GetValue()) + opts.Headers[k] = string(v.GetValue()) + } } return res, nil diff --git a/client/orb/transport/basehttp/basehttp.go b/client/orb/transport/basehttp/basehttp.go index a44c649d..896cd850 100644 --- a/client/orb/transport/basehttp/basehttp.go +++ b/client/orb/transport/basehttp/basehttp.go @@ -9,8 +9,8 @@ import ( "fmt" "io" "net/http" + "slices" "strconv" - "strings" "github.com/go-orb/go-orb/client" "github.com/go-orb/go-orb/codecs" @@ -20,8 +20,8 @@ import ( "github.com/go-orb/plugins/client/orb" ) -// orbHeader is the prefix for every orb HTTP header. -const orbHeader = "__orb-" +//nolint:gochecknoglobals +var stdHeaders = []string{"Content-Length", "Content-Type", "Date", "Server"} var _ (orb.Transport) = (*Transport)(nil) @@ -95,17 +95,17 @@ func (t *Transport) Call(ctx context.Context, req *client.Request[any, any], opt hReq.Header.Set("Accept", opts.ContentType) // Set metadata key=value to request headers. - md, ok := metadata.From(ctx) + md, ok := metadata.OutgoingFrom(ctx) if ok { for name, value := range md { - hReq.Header.Set(orbHeader+name, value) + hReq.Header.Set(name, value) } } - return t.call2(hReq) + return t.call2(hReq, opts) } -func (t *Transport) call2(hReq *http.Request) (*client.RawResponse, error) { +func (t *Transport) call2(hReq *http.Request, opts *client.CallOptions) (*client.RawResponse, error) { // Run the request. resp, err := t.hclient.Do(hReq) if err != nil { @@ -128,23 +128,26 @@ func (t *Transport) call2(hReq *http.Request) (*client.RawResponse, error) { res := &client.RawResponse{ ContentType: resp.Header.Get("Content-Type"), Body: buff, - Metadata: make(metadata.Metadata), } - // Copy headers to the RawResponse. - for k, v := range resp.Header { - if !strings.HasPrefix(strings.ToLower(k), orbHeader) { - continue - } + if opts.Headers != nil { + md := opts.Headers + + // Copy headers to opts.Header + for k, v := range resp.Header { + // Skip std headers. + if slices.Contains(stdHeaders, k) { + continue + } - k = k[len(orbHeader):] + if len(v) == 1 { + md[k] = v[0] + } else { + md[k] = v[0] - if len(v) == 1 { - res.Metadata[k] = v[0] - } else { - res.Metadata[k] = v[0] - for i := 1; i < len(v); i++ { - res.Metadata[k+"-"+strconv.Itoa(i)] = v[i] + for i := 1; i < len(v); i++ { + md[k+"-"+strconv.Itoa(i)] = v[i] + } } } } diff --git a/client/orb/transport/drpc/drpc.go b/client/orb/transport/drpc/drpc.go index f7cc5412..c3b6cc66 100644 --- a/client/orb/transport/drpc/drpc.go +++ b/client/orb/transport/drpc/drpc.go @@ -9,14 +9,18 @@ import ( "storj.io/drpc" "storj.io/drpc/drpcconn" + "storj.io/drpc/drpcerr" + "storj.io/drpc/drpcmetadata" "storj.io/drpc/drpcpool" "google.golang.org/protobuf/proto" "github.com/go-orb/go-orb/client" "github.com/go-orb/go-orb/log" + "github.com/go-orb/go-orb/util/metadata" "github.com/go-orb/go-orb/util/orberrors" "github.com/go-orb/plugins/client/orb" + "github.com/go-orb/plugins/server/drpc/message" ) var _ drpc.Encoding = (*encoder)(nil) @@ -98,14 +102,32 @@ func (t *Transport) CallNoCodec(ctx context.Context, req *client.Request[any, an conn := t.pool.Get(ctx, node.Address, dial) + // Add metadata to drpc. + md, ok := metadata.OutgoingFrom(ctx) + if ok { + ctx = drpcmetadata.AddPairs(ctx, md) + } + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(opts.RequestTimeout)) defer cancel() - err = conn.Invoke(ctx, "/"+req.Endpoint(), &t.encoder, req.Request(), result) - if err != nil { + mdResult := &message.Response{} + if err := conn.Invoke(ctx, "/"+req.Endpoint(), &t.encoder, req.Request(), mdResult); err != nil { + return orberrors.New(int(drpcerr.Code(err)), err.Error()) //nolint:gosec + } + + // Unmarshal the result. + if err := mdResult.GetData().UnmarshalTo(result.(proto.Message)); err != nil { return orberrors.From(err) } + // Retrieve metadata from drpc. + if opts.Headers != nil { + for k, v := range mdResult.GetMetadata() { + opts.Headers[k] = v + } + } + err = conn.Close() if err != nil { return orberrors.From(err) diff --git a/client/orb/transport/grpc/go.mod b/client/orb/transport/grpc/go.mod index 5fb5c241..26bfbe17 100644 --- a/client/orb/transport/grpc/go.mod +++ b/client/orb/transport/grpc/go.mod @@ -18,4 +18,4 @@ require ( golang.org/x/text v0.17.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/protobuf v1.34.2 // indirect -) +) \ No newline at end of file diff --git a/client/orb/transport/grpc/grpc.go b/client/orb/transport/grpc/grpc.go index 4e74ba83..82b512da 100644 --- a/client/orb/transport/grpc/grpc.go +++ b/client/orb/transport/grpc/grpc.go @@ -8,10 +8,12 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + gmetadata "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "github.com/go-orb/go-orb/client" "github.com/go-orb/go-orb/log" + "github.com/go-orb/go-orb/util/metadata" "github.com/go-orb/go-orb/util/orberrors" "github.com/go-orb/plugins/client/orb" "github.com/go-orb/plugins/client/orb/transport/grpc/pool" @@ -59,6 +61,8 @@ func (t *Transport) Call(_ context.Context, _ *client.Request[any, any], _ *clie } // CallNoCodec does the actual rpc call to the server. +// +//nolint:funlen func (t *Transport) CallNoCodec(ctx context.Context, req *client.Request[any, any], result any, opts *client.CallOptions) error { node, err := req.Node(ctx, opts) if err != nil { @@ -92,10 +96,22 @@ func (t *Transport) CallNoCodec(ctx context.Context, req *client.Request[any, an return orberrors.From(err) } + // Append go-orb metadata to grpc. + if md, ok := metadata.OutgoingFrom(ctx); ok { + kv := []string{} + for k, v := range md { + kv = append(kv, k, v) + } + + ctx = gmetadata.AppendToOutgoingContext(ctx, kv...) + } + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(opts.RequestTimeout)) defer cancel() - err = conn.Invoke(ctx, "/"+req.Endpoint(), req.Request(), result) + resMeta := gmetadata.MD{} + + err = conn.Invoke(ctx, "/"+req.Endpoint(), req.Request(), result, grpc.Header(&resMeta)) if err != nil { gErr, ok := status.FromError(err) if !ok { @@ -111,6 +127,12 @@ func (t *Transport) CallNoCodec(ctx context.Context, req *client.Request[any, an return orberrors.New(httpStatusCode, gErr.Message()) } + if opts.Headers != nil { + for k, v := range resMeta { + opts.Headers[k] = v[0] + } + } + err = conn.Close() if err != nil { gErr, ok := status.FromError(err) diff --git a/client/tests/cmd/tests_server/wire_gen.go b/client/tests/cmd/tests_server/wire_gen.go index 30ffc5b8..1ddbf9bb 100644 --- a/client/tests/cmd/tests_server/wire_gen.go +++ b/client/tests/cmd/tests_server/wire_gen.go @@ -44,12 +44,12 @@ func newComponents(serviceName types.ServiceName, serviceVersion types.ServiceVe return nil, err } v := _wireValue - logger, err := log.ProvideLogger(serviceName, configData, v...) + logger, err := log.Provide(serviceName, configData, v...) if err != nil { return nil, err } v2 := _wireValue2 - registryType, err := registry.ProvideRegistry(serviceName, serviceVersion, configData, logger, v2...) + registryType, err := registry.Provide(serviceName, serviceVersion, configData, logger, v2...) if err != nil { return nil, err } @@ -57,7 +57,7 @@ func newComponents(serviceName types.ServiceName, serviceVersion types.ServiceVe if err != nil { return nil, err } - serverServer, err := server.ProvideServer(serviceName, configData, logger, registryType, v3...) + serverServer, err := server.Provide(serviceName, configData, logger, registryType, v3...) if err != nil { return nil, err } diff --git a/client/tests/go.mod b/client/tests/go.mod index 268ba61f..7e6c729b 100644 --- a/client/tests/go.mod +++ b/client/tests/go.mod @@ -24,12 +24,11 @@ require ( github.com/go-orb/plugins/server/drpc v0.0.0-20240925073122-2250c978d160 github.com/go-orb/plugins/server/grpc v0.0.0-20240925073122-2250c978d160 github.com/go-orb/plugins/server/hertz v0.0.0-20240925070424-371b8463d2d6 - github.com/go-orb/plugins/server/http v0.0.0-20240925070424-371b8463d2d6 + github.com/go-orb/plugins/server/http v0.0.0-20240925074005-17c0b37c3d6b github.com/google/wire v0.6.0 github.com/hashicorp/consul/sdk v0.16.1 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 - google.golang.org/genproto/googleapis/api v0.0.0-20240924160255-9d4c2d233b61 google.golang.org/grpc v1.67.0 google.golang.org/protobuf v1.34.2 storj.io/drpc v0.0.34 diff --git a/client/tests/go.sum b/client/tests/go.sum index c72091ab..d72f68b2 100644 --- a/client/tests/go.sum +++ b/client/tests/go.sum @@ -83,8 +83,8 @@ github.com/go-orb/plugins/server/grpc v0.0.0-20240925073122-2250c978d160 h1:JYdS github.com/go-orb/plugins/server/grpc v0.0.0-20240925073122-2250c978d160/go.mod h1:sLf3tGSGqbYm03XxF+6BKKLV2WJtUSrr6ja0IAVzmHs= github.com/go-orb/plugins/server/hertz v0.0.0-20240925070424-371b8463d2d6 h1:i3wNZnv9uvB0NfHOg9CnWi/jv2cNqLV5QstyZo6NyVU= github.com/go-orb/plugins/server/hertz v0.0.0-20240925070424-371b8463d2d6/go.mod h1:6LZZaZJ7dUMNgzwrOWfpo0xCo3xC0HpqSuQnK3XXUuw= -github.com/go-orb/plugins/server/http v0.0.0-20240925070424-371b8463d2d6 h1:nRhQaeJQ715AdVRsLTGKMa8QLRsHCk+QhRWafm2Oe0g= -github.com/go-orb/plugins/server/http v0.0.0-20240925070424-371b8463d2d6/go.mod h1:T4mUQ7uIMwc5M+GJHYphE8oqXNJerCCXckfhJlUJRls= +github.com/go-orb/plugins/server/http v0.0.0-20240925074005-17c0b37c3d6b h1:jU+dPzIwWhU1uphO6m5Cgz+WtcHZhfM4ULBasU6I0GA= +github.com/go-orb/plugins/server/http v0.0.0-20240925074005-17c0b37c3d6b/go.mod h1:S5Sk+NBYXwcvcnjVrp26RsuWxJB0k7qsdOdD/aDeGX8= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -253,8 +253,6 @@ golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/api v0.0.0-20240924160255-9d4c2d233b61 h1:pAjq8XSSzXoP9ya73v/w+9QEAAJNluLrpmMq5qFJQNY= -google.golang.org/genproto/googleapis/api v0.0.0-20240924160255-9d4c2d233b61/go.mod h1:O6rP0uBq4k0mdi/b4ZEMAZjkhYWhS815kCvaMha4VN8= google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61 h1:N9BgCIAUvn/M+p4NJccWPWb3BWh88+zyL0ll9HgbEeM= google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= google.golang.org/grpc v1.67.0 h1:IdH9y6PF5MPSdAntIcpjQ+tXO41pcQsfZV2RxtQgVcw= diff --git a/client/tests/handler/handler.go b/client/tests/handler/handler.go index 71544704..b5388879 100644 --- a/client/tests/handler/handler.go +++ b/client/tests/handler/handler.go @@ -6,6 +6,8 @@ import ( "crypto/rand" "errors" + "github.com/go-orb/go-orb/util/metadata" + "github.com/go-orb/go-orb/util/orberrors" "github.com/go-orb/plugins/client/tests/proto" ) @@ -17,8 +19,8 @@ type EchoHandler struct { } // Call implements the call method. -func (c *EchoHandler) Call(_ context.Context, in *proto.CallRequest) (*proto.CallResponse, error) { - switch in.GetName() { +func (c *EchoHandler) Call(_ context.Context, request *proto.CallRequest) (*proto.CallResponse, error) { + switch request.GetName() { case "error": return nil, errors.New("you asked for an error, here you go") case "32byte": @@ -35,8 +37,21 @@ func (c *EchoHandler) Call(_ context.Context, in *proto.CallRequest) (*proto.Cal return nil, err } - return &proto.CallResponse{Msg: "Hello " + in.GetName(), Payload: msg}, nil + return &proto.CallResponse{Msg: "Hello " + request.GetName(), Payload: msg}, nil default: - return &proto.CallResponse{Msg: "Hello " + in.GetName()}, nil + return &proto.CallResponse{Msg: "Hello " + request.GetName()}, nil } } + +// AuthorizedCall requires Authorization by metadata. +func (c *EchoHandler) AuthorizedCall(ctx context.Context, _ *proto.CallRequest) (*proto.CallResponse, error) { + mdout, _ := metadata.OutgoingFrom(ctx) + mdout["tracing-id"] = "asfdjhladhsfashf" + + mdin, _ := metadata.IncomingFrom(ctx) + if mdin["authorization"] != "bearer pleaseHackMe" { + return nil, orberrors.ErrUnauthorized + } + + return &proto.CallResponse{Msg: "Hello World"}, nil +} diff --git a/client/tests/proto/echo.pb.go b/client/tests/proto/echo.pb.go index f29af172..69ebd6d1 100644 --- a/client/tests/proto/echo.pb.go +++ b/client/tests/proto/echo.pb.go @@ -1,13 +1,12 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.31.0 -// protoc v4.25.1 +// protoc-gen-go v1.34.2 +// protoc v5.27.3 // source: echo.proto package proto import ( - _ "google.golang.org/genproto/googleapis/api/annotations" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -127,22 +126,21 @@ var File_echo_proto protoreflect.FileDescriptor var file_echo_proto_rawDesc = []byte{ 0x0a, 0x0a, 0x65, 0x63, 0x68, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x65, 0x63, - 0x68, 0x6f, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, - 0x6e, 0x6e, 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x22, 0x21, 0x0a, 0x0b, 0x43, 0x61, 0x6c, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, - 0x61, 0x6d, 0x65, 0x22, 0x3a, 0x0a, 0x0c, 0x43, 0x61, 0x6c, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x6d, 0x73, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x32, - 0x57, 0x0a, 0x07, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x12, 0x4c, 0x0a, 0x04, 0x43, 0x61, - 0x6c, 0x6c, 0x12, 0x11, 0x2e, 0x65, 0x63, 0x68, 0x6f, 0x2e, 0x43, 0x61, 0x6c, 0x6c, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x65, 0x63, 0x68, 0x6f, 0x2e, 0x43, 0x61, 0x6c, - 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1d, 0x82, 0xd3, 0xe4, 0x93, 0x02, - 0x17, 0x3a, 0x01, 0x2a, 0x22, 0x12, 0x2f, 0x65, 0x63, 0x68, 0x6f, 0x2e, 0x53, 0x74, 0x72, 0x65, - 0x61, 0x6d, 0x73, 0x2f, 0x43, 0x61, 0x6c, 0x6c, 0x42, 0x0f, 0x5a, 0x0d, 0x2e, 0x2f, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x3b, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x68, 0x6f, 0x22, 0x21, 0x0a, 0x0b, 0x43, 0x61, 0x6c, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x3a, 0x0a, 0x0c, 0x43, 0x61, 0x6c, 0x6c, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, + 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, + 0x64, 0x32, 0x71, 0x0a, 0x07, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x12, 0x2d, 0x0a, 0x04, + 0x43, 0x61, 0x6c, 0x6c, 0x12, 0x11, 0x2e, 0x65, 0x63, 0x68, 0x6f, 0x2e, 0x43, 0x61, 0x6c, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x65, 0x63, 0x68, 0x6f, 0x2e, 0x43, + 0x61, 0x6c, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x37, 0x0a, 0x0e, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x43, 0x61, 0x6c, 0x6c, 0x12, 0x11, 0x2e, + 0x65, 0x63, 0x68, 0x6f, 0x2e, 0x43, 0x61, 0x6c, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x12, 0x2e, 0x65, 0x63, 0x68, 0x6f, 0x2e, 0x43, 0x61, 0x6c, 0x6c, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x0f, 0x5a, 0x0d, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x3b, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -158,15 +156,17 @@ func file_echo_proto_rawDescGZIP() []byte { } var file_echo_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_echo_proto_goTypes = []interface{}{ +var file_echo_proto_goTypes = []any{ (*CallRequest)(nil), // 0: echo.CallRequest (*CallResponse)(nil), // 1: echo.CallResponse } var file_echo_proto_depIdxs = []int32{ 0, // 0: echo.Streams.Call:input_type -> echo.CallRequest - 1, // 1: echo.Streams.Call:output_type -> echo.CallResponse - 1, // [1:2] is the sub-list for method output_type - 0, // [0:1] is the sub-list for method input_type + 0, // 1: echo.Streams.AuthorizedCall:input_type -> echo.CallRequest + 1, // 2: echo.Streams.Call:output_type -> echo.CallResponse + 1, // 3: echo.Streams.AuthorizedCall:output_type -> echo.CallResponse + 2, // [2:4] is the sub-list for method output_type + 0, // [0:2] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -178,7 +178,7 @@ func file_echo_proto_init() { return } if !protoimpl.UnsafeEnabled { - file_echo_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_echo_proto_msgTypes[0].Exporter = func(v any, i int) any { switch v := v.(*CallRequest); i { case 0: return &v.state @@ -190,7 +190,7 @@ func file_echo_proto_init() { return nil } } - file_echo_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_echo_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*CallResponse); i { case 0: return &v.state diff --git a/client/tests/proto/echo.proto b/client/tests/proto/echo.proto index d1d07822..80b4d606 100644 --- a/client/tests/proto/echo.proto +++ b/client/tests/proto/echo.proto @@ -2,17 +2,11 @@ syntax = "proto3"; package echo; -import "google/api/annotations.proto"; - option go_package = "./proto;proto"; service Streams { - rpc Call(CallRequest) returns (CallResponse) { - option (google.api.http) = { - post : "/echo.Streams/Call" - body : "*" - }; - } + rpc Call(CallRequest) returns (CallResponse); + rpc AuthorizedCall(CallRequest) returns (CallResponse); } message CallRequest { string name = 1; } diff --git a/client/tests/proto/echo_grpc.pb.go b/client/tests/proto/echo_grpc.pb.go deleted file mode 100644 index 8e85dd76..00000000 --- a/client/tests/proto/echo_grpc.pb.go +++ /dev/null @@ -1,105 +0,0 @@ -// Code generated by protoc-gen-go-grpc. DO NOT EDIT. -// versions: -// - protoc-gen-go-grpc v1.2.0 -// - protoc v4.25.1 -// source: echo.proto - -package proto - -import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" -) - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 - -// StreamsClient is the client API for Streams service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. -type StreamsClient interface { - Call(ctx context.Context, in *CallRequest, opts ...grpc.CallOption) (*CallResponse, error) -} - -type streamsClient struct { - cc grpc.ClientConnInterface -} - -func NewStreamsClient(cc grpc.ClientConnInterface) StreamsClient { - return &streamsClient{cc} -} - -func (c *streamsClient) Call(ctx context.Context, in *CallRequest, opts ...grpc.CallOption) (*CallResponse, error) { - out := new(CallResponse) - err := c.cc.Invoke(ctx, "/echo.Streams/Call", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// StreamsServer is the server API for Streams service. -// All implementations must embed UnimplementedStreamsServer -// for forward compatibility -type StreamsServer interface { - Call(context.Context, *CallRequest) (*CallResponse, error) - mustEmbedUnimplementedStreamsServer() -} - -// UnimplementedStreamsServer must be embedded to have forward compatible implementations. -type UnimplementedStreamsServer struct { -} - -func (UnimplementedStreamsServer) Call(context.Context, *CallRequest) (*CallResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Call not implemented") -} -func (UnimplementedStreamsServer) mustEmbedUnimplementedStreamsServer() {} - -// UnsafeStreamsServer may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to StreamsServer will -// result in compilation errors. -type UnsafeStreamsServer interface { - mustEmbedUnimplementedStreamsServer() -} - -func RegisterStreamsServer(s grpc.ServiceRegistrar, srv StreamsServer) { - s.RegisterService(&Streams_ServiceDesc, srv) -} - -func _Streams_Call_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(CallRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(StreamsServer).Call(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/echo.Streams/Call", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(StreamsServer).Call(ctx, req.(*CallRequest)) - } - return interceptor(ctx, in, info, handler) -} - -// Streams_ServiceDesc is the grpc.ServiceDesc for Streams service. -// It's only intended for direct use with grpc.RegisterService, -// and not to be introspected or modified (even as a copy) -var Streams_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "echo.Streams", - HandlerType: (*StreamsServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "Call", - Handler: _Streams_Call_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "echo.proto", -} diff --git a/client/tests/proto/echo_drpc.pb.go b/client/tests/proto/echo_orb-drpc.pb.go similarity index 62% rename from client/tests/proto/echo_drpc.pb.go rename to client/tests/proto/echo_orb-drpc.pb.go index ffc667b6..633c857f 100644 --- a/client/tests/proto/echo_drpc.pb.go +++ b/client/tests/proto/echo_orb-drpc.pb.go @@ -1,5 +1,9 @@ -// Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.33 +// Code generated by protoc-gen-go-orb. DO NOT EDIT. +// +// version: +// - protoc-gen-go-orb v0.0.1 +// - protoc v5.27.3 +// // source: echo.proto package proto @@ -35,33 +39,9 @@ func (drpcEncoding_File_echo_proto) JSONUnmarshal(buf []byte, msg drpc.Message) return protojson.Unmarshal(buf, msg.(proto.Message)) } -type DRPCStreamsClient interface { - DRPCConn() drpc.Conn - - Call(ctx context.Context, in *CallRequest) (*CallResponse, error) -} - -type drpcStreamsClient struct { - cc drpc.Conn -} - -func NewDRPCStreamsClient(cc drpc.Conn) DRPCStreamsClient { - return &drpcStreamsClient{cc} -} - -func (c *drpcStreamsClient) DRPCConn() drpc.Conn { return c.cc } - -func (c *drpcStreamsClient) Call(ctx context.Context, in *CallRequest) (*CallResponse, error) { - out := new(CallResponse) - err := c.cc.Invoke(ctx, "/echo.Streams/Call", drpcEncoding_File_echo_proto{}, in, out) - if err != nil { - return nil, err - } - return out, nil -} - type DRPCStreamsServer interface { Call(context.Context, *CallRequest) (*CallResponse, error) + AuthorizedCall(context.Context, *CallRequest) (*CallResponse, error) } type DRPCStreamsUnimplementedServer struct{} @@ -70,9 +50,13 @@ func (s *DRPCStreamsUnimplementedServer) Call(context.Context, *CallRequest) (*C return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) } +func (s *DRPCStreamsUnimplementedServer) AuthorizedCall(context.Context, *CallRequest) (*CallResponse, error) { + return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) +} + type DRPCStreamsDescription struct{} -func (DRPCStreamsDescription) NumMethods() int { return 1 } +func (DRPCStreamsDescription) NumMethods() int { return 2 } func (DRPCStreamsDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) { switch n { @@ -85,15 +69,20 @@ func (DRPCStreamsDescription) Method(n int) (string, drpc.Encoding, drpc.Receive in1.(*CallRequest), ) }, DRPCStreamsServer.Call, true + case 1: + return "/echo.Streams/AuthorizedCall", drpcEncoding_File_echo_proto{}, + func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { + return srv.(DRPCStreamsServer). + AuthorizedCall( + ctx, + in1.(*CallRequest), + ) + }, DRPCStreamsServer.AuthorizedCall, true default: return "", nil, nil, nil, false } } -func DRPCRegisterStreams(mux drpc.Mux, impl DRPCStreamsServer) error { - return mux.Register(impl, DRPCStreamsDescription{}) -} - type DRPCStreams_CallStream interface { drpc.Stream SendAndClose(*CallResponse) error @@ -103,9 +92,33 @@ type drpcStreams_CallStream struct { drpc.Stream } +func (x *drpcStreams_CallStream) GetStream() drpc.Stream { + return x.Stream +} + func (x *drpcStreams_CallStream) SendAndClose(m *CallResponse) error { if err := x.MsgSend(m, drpcEncoding_File_echo_proto{}); err != nil { return err } return x.CloseSend() } + +type DRPCStreams_AuthorizedCallStream interface { + drpc.Stream + SendAndClose(*CallResponse) error +} + +type drpcStreams_AuthorizedCallStream struct { + drpc.Stream +} + +func (x *drpcStreams_AuthorizedCallStream) GetStream() drpc.Stream { + return x.Stream +} + +func (x *drpcStreams_AuthorizedCallStream) SendAndClose(m *CallResponse) error { + if err := x.MsgSend(m, drpcEncoding_File_echo_proto{}); err != nil { + return err + } + return x.CloseSend() +} diff --git a/client/tests/proto/echo_orb-grpc.pb.go b/client/tests/proto/echo_orb-grpc.pb.go new file mode 100644 index 00000000..1af5fb46 --- /dev/null +++ b/client/tests/proto/echo_orb-grpc.pb.go @@ -0,0 +1,121 @@ +// Code generated by protoc-gen-go-orb-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-orb v0.0.1 +// - protoc v5.27.3 +// source: echo.proto + +package proto + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + Streams_Call_FullMethodName = "/echo.Streams/Call" + Streams_AuthorizedCall_FullMethodName = "/echo.Streams/AuthorizedCall" +) + +// StreamsServer is the server API for Streams service. +// All implementations should embed UnimplementedStreamsServer +// for forward compatibility. +type StreamsServer interface { + Call(context.Context, *CallRequest) (*CallResponse, error) + AuthorizedCall(context.Context, *CallRequest) (*CallResponse, error) +} + +// UnimplementedStreamsServer should be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedStreamsServer struct{} + +func (UnimplementedStreamsServer) Call(context.Context, *CallRequest) (*CallResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Call not implemented") +} +func (UnimplementedStreamsServer) AuthorizedCall(context.Context, *CallRequest) (*CallResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method AuthorizedCall not implemented") +} +func (UnimplementedStreamsServer) testEmbeddedByValue() {} + +// UnsafeStreamsServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to StreamsServer will +// result in compilation errors. +type UnsafeStreamsServer interface { + mustEmbedUnimplementedStreamsServer() +} + +func registerStreamsGRPCHandler(s grpc.ServiceRegistrar, srv StreamsServer) { + // If the following call panics, it indicates UnimplementedStreamsServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&Streams_ServiceDesc, srv) +} + +func _Streams_Call_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CallRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(StreamsServer).Call(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Streams_Call_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(StreamsServer).Call(ctx, req.(*CallRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Streams_AuthorizedCall_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CallRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(StreamsServer).AuthorizedCall(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Streams_AuthorizedCall_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(StreamsServer).AuthorizedCall(ctx, req.(*CallRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Streams_ServiceDesc is the grpc.ServiceDesc for Streams service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Streams_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "echo.Streams", + HandlerType: (*StreamsServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Call", + Handler: _Streams_Call_Handler, + }, + { + MethodName: "AuthorizedCall", + Handler: _Streams_AuthorizedCall_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "echo.proto", +} diff --git a/client/tests/proto/echo_orb.pb.go b/client/tests/proto/echo_orb.pb.go index 32e3838a..0e385290 100644 --- a/client/tests/proto/echo_orb.pb.go +++ b/client/tests/proto/echo_orb.pb.go @@ -1,8 +1,8 @@ -// Code generated by protoc-gen-go-orb-http. DO NOT EDIT. +// Code generated by protoc-gen-go-orb. DO NOT EDIT. // // version: -// - protoc-gen-go-orb-http v1.0.0 -// - protoc v4.25.1 +// - protoc-gen-go-orb v0.0.1 +// - protoc v5.27.3 // // Proto source: echo.proto @@ -11,59 +11,75 @@ package proto import ( "context" - "google.golang.org/grpc" - + "github.com/go-orb/go-orb/log" "github.com/go-orb/go-orb/server" + grpc "google.golang.org/grpc" + mdrpc "github.com/go-orb/plugins/server/drpc" mhertz "github.com/go-orb/plugins/server/hertz" mhttp "github.com/go-orb/plugins/server/http" ) -type orbStreamsHandler interface { - Call(context.Context, *CallRequest) (*CallResponse, error) - mustEmbedUnimplementedStreamsServer() -} +type StreamsHandler interface { + Call(ctx context.Context, req *CallRequest) (*CallResponse, error) -func registerStreamsHTTPHandler(srv *mhttp.ServerHTTP, handler orbStreamsHandler) { - r := srv.Router() - r.Post("/echo.Streams/Call", mhttp.NewGRPCHandler(srv, handler.Call)) -} - -func registerStreamsHertzHandler(srv *mhertz.Server, handler orbStreamsHandler) { - r := srv.Router() - r.POST("/echo.Streams/Call", mhertz.NewGRPCHandler(srv, handler.Call)) + AuthorizedCall(ctx context.Context, req *CallRequest) (*CallResponse, error) } -func registerStreamsDRPCHandler(srv *mdrpc.Server, handler orbStreamsHandler) { +func registerStreamsDRPCHandler(srv *mdrpc.Server, handler StreamsHandler) error { desc := DRPCStreamsDescription{} // Register with DRPC. r := srv.Router() - r.Register(handler, desc) + + // Register with the server/drpc(.Mux). + err := r.Register(handler, desc) + if err != nil { + return err + } // Add each endpoint name of this handler to the orb drpc server. for i := 0; i < desc.NumMethods(); i++ { name, _, _, _, _ := desc.Method(i) srv.AddEndpoint(name) } + + return nil +} + +// registerStreamsHTTPHandler registers the service to an HTTP server. +func registerStreamsHTTPHandler(srv *mhttp.ServerHTTP, handler StreamsHandler) { + r := srv.Router() + + r.Post("/echo.Streams/Call", mhttp.NewGRPCHandler(srv, handler.Call, "echo.Streams", "Call")) + r.Post("/echo.Streams/AuthorizedCall", mhttp.NewGRPCHandler(srv, handler.AuthorizedCall, "echo.Streams", "AuthorizedCall")) +} + +// registerStreamsHertzHandler registers the service to an Hertz server. +func registerStreamsHertzHandler(srv *mhertz.Server, handler StreamsHandler) { + r := srv.Router() + + r.POST("/echo.Streams/Call", mhertz.NewGRPCHandler(srv, handler.Call, "echo.Streams", "Call")) + r.POST("/echo.Streams/AuthorizedCall", mhertz.NewGRPCHandler(srv, handler.AuthorizedCall, "echo.Streams", "AuthorizedCall")) } // RegisterStreamsHandler will return a registration function that can be // provided to entrypoints as a handler registration. -func RegisterStreamsHandler(handler any) server.RegistrationFunc { +func RegisterStreamsHandler(handler StreamsHandler) server.RegistrationFunc { return server.RegistrationFunc(func(s any) { - switch srv := any(s).(type) { - case *mhttp.ServerHTTP: - registerStreamsHTTPHandler(srv, handler.(orbStreamsHandler)) - case *mhertz.Server: - registerStreamsHertzHandler(srv, handler.(orbStreamsHandler)) - case *mdrpc.Server: - registerStreamsDRPCHandler(srv, handler.(orbStreamsHandler)) + switch srv := s.(type) { + case grpc.ServiceRegistrar: - RegisterStreamsServer(srv, handler.(StreamsServer)) + registerStreamsGRPCHandler(srv, handler) + case *mdrpc.Server: + registerStreamsDRPCHandler(srv, handler) + case *mhertz.Server: + registerStreamsHertzHandler(srv, handler) + case *mhttp.ServerHTTP: + registerStreamsHTTPHandler(srv, handler) default: - // Maybe we should log here with slog global logger + log.Warn("No provider for this server found", "proto", "echo.proto", "handler", "Streams", "server", s) } }) } diff --git a/client/tests/proto/gen.go b/client/tests/proto/gen.go index 58dbc1d3..4fbdb7f6 100644 --- a/client/tests/proto/gen.go +++ b/client/tests/proto/gen.go @@ -1,9 +1,5 @@ // Package proto ... package proto -// Download Google proto HTTP annotation libs -//go:generate wget -q -O google/api/annotations.proto https://raw.githubusercontent.com/googleapis/googleapis/master/google/api/annotations.proto -//go:generate wget -q -O google/api/http.proto https://raw.githubusercontent.com/googleapis/googleapis/master/google/api/http.proto - // Generate proto files -//go:generate protoc -I . --go-grpc_out=paths=source_relative:. --go-micro-http_out=paths=source_relative:. --go_out=paths=source_relative:. --go-drpc_out=paths=source_relative:. ./echo.proto +//go:generate protoc -I . --go-orb_out=paths=source_relative:. --go-orb_opt=supported_servers=drpc;grpc;http;hertz echo.proto diff --git a/client/tests/proto/google/api/annotations.proto b/client/tests/proto/google/api/annotations.proto deleted file mode 100644 index efdab3db..00000000 --- a/client/tests/proto/google/api/annotations.proto +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2015 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package google.api; - -import "google/api/http.proto"; -import "google/protobuf/descriptor.proto"; - -option go_package = "google.golang.org/genproto/googleapis/api/annotations;annotations"; -option java_multiple_files = true; -option java_outer_classname = "AnnotationsProto"; -option java_package = "com.google.api"; -option objc_class_prefix = "GAPI"; - -extend google.protobuf.MethodOptions { - // See `HttpRule`. - HttpRule http = 72295728; -} diff --git a/client/tests/proto/google/api/http.proto b/client/tests/proto/google/api/http.proto deleted file mode 100644 index 31d867a2..00000000 --- a/client/tests/proto/google/api/http.proto +++ /dev/null @@ -1,379 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package google.api; - -option cc_enable_arenas = true; -option go_package = "google.golang.org/genproto/googleapis/api/annotations;annotations"; -option java_multiple_files = true; -option java_outer_classname = "HttpProto"; -option java_package = "com.google.api"; -option objc_class_prefix = "GAPI"; - -// Defines the HTTP configuration for an API service. It contains a list of -// [HttpRule][google.api.HttpRule], each specifying the mapping of an RPC method -// to one or more HTTP REST API methods. -message Http { - // A list of HTTP configuration rules that apply to individual API methods. - // - // **NOTE:** All service configuration rules follow "last one wins" order. - repeated HttpRule rules = 1; - - // When set to true, URL path parameters will be fully URI-decoded except in - // cases of single segment matches in reserved expansion, where "%2F" will be - // left encoded. - // - // The default behavior is to not decode RFC 6570 reserved characters in multi - // segment matches. - bool fully_decode_reserved_expansion = 2; -} - -// # gRPC Transcoding -// -// gRPC Transcoding is a feature for mapping between a gRPC method and one or -// more HTTP REST endpoints. It allows developers to build a single API service -// that supports both gRPC APIs and REST APIs. Many systems, including [Google -// APIs](https://github.com/googleapis/googleapis), -// [Cloud Endpoints](https://cloud.google.com/endpoints), [gRPC -// Gateway](https://github.com/grpc-ecosystem/grpc-gateway), -// and [Envoy](https://github.com/envoyproxy/envoy) proxy support this feature -// and use it for large scale production services. -// -// `HttpRule` defines the schema of the gRPC/REST mapping. The mapping specifies -// how different portions of the gRPC request message are mapped to the URL -// path, URL query parameters, and HTTP request body. It also controls how the -// gRPC response message is mapped to the HTTP response body. `HttpRule` is -// typically specified as an `google.api.http` annotation on the gRPC method. -// -// Each mapping specifies a URL path template and an HTTP method. The path -// template may refer to one or more fields in the gRPC request message, as long -// as each field is a non-repeated field with a primitive (non-message) type. -// The path template controls how fields of the request message are mapped to -// the URL path. -// -// Example: -// -// service Messaging { -// rpc GetMessage(GetMessageRequest) returns (Message) { -// option (google.api.http) = { -// get: "/v1/{name=messages/*}" -// }; -// } -// } -// message GetMessageRequest { -// string name = 1; // Mapped to URL path. -// } -// message Message { -// string text = 1; // The resource content. -// } -// -// This enables an HTTP REST to gRPC mapping as below: -// -// HTTP | gRPC -// -----|----- -// `GET /v1/messages/123456` | `GetMessage(name: "messages/123456")` -// -// Any fields in the request message which are not bound by the path template -// automatically become HTTP query parameters if there is no HTTP request body. -// For example: -// -// service Messaging { -// rpc GetMessage(GetMessageRequest) returns (Message) { -// option (google.api.http) = { -// get:"/v1/messages/{message_id}" -// }; -// } -// } -// message GetMessageRequest { -// message SubMessage { -// string subfield = 1; -// } -// string message_id = 1; // Mapped to URL path. -// int64 revision = 2; // Mapped to URL query parameter `revision`. -// SubMessage sub = 3; // Mapped to URL query parameter `sub.subfield`. -// } -// -// This enables a HTTP JSON to RPC mapping as below: -// -// HTTP | gRPC -// -----|----- -// `GET /v1/messages/123456?revision=2&sub.subfield=foo` | -// `GetMessage(message_id: "123456" revision: 2 sub: SubMessage(subfield: -// "foo"))` -// -// Note that fields which are mapped to URL query parameters must have a -// primitive type or a repeated primitive type or a non-repeated message type. -// In the case of a repeated type, the parameter can be repeated in the URL -// as `...?param=A¶m=B`. In the case of a message type, each field of the -// message is mapped to a separate parameter, such as -// `...?foo.a=A&foo.b=B&foo.c=C`. -// -// For HTTP methods that allow a request body, the `body` field -// specifies the mapping. Consider a REST update method on the -// message resource collection: -// -// service Messaging { -// rpc UpdateMessage(UpdateMessageRequest) returns (Message) { -// option (google.api.http) = { -// patch: "/v1/messages/{message_id}" -// body: "message" -// }; -// } -// } -// message UpdateMessageRequest { -// string message_id = 1; // mapped to the URL -// Message message = 2; // mapped to the body -// } -// -// The following HTTP JSON to RPC mapping is enabled, where the -// representation of the JSON in the request body is determined by -// protos JSON encoding: -// -// HTTP | gRPC -// -----|----- -// `PATCH /v1/messages/123456 { "text": "Hi!" }` | `UpdateMessage(message_id: -// "123456" message { text: "Hi!" })` -// -// The special name `*` can be used in the body mapping to define that -// every field not bound by the path template should be mapped to the -// request body. This enables the following alternative definition of -// the update method: -// -// service Messaging { -// rpc UpdateMessage(Message) returns (Message) { -// option (google.api.http) = { -// patch: "/v1/messages/{message_id}" -// body: "*" -// }; -// } -// } -// message Message { -// string message_id = 1; -// string text = 2; -// } -// -// -// The following HTTP JSON to RPC mapping is enabled: -// -// HTTP | gRPC -// -----|----- -// `PATCH /v1/messages/123456 { "text": "Hi!" }` | `UpdateMessage(message_id: -// "123456" text: "Hi!")` -// -// Note that when using `*` in the body mapping, it is not possible to -// have HTTP parameters, as all fields not bound by the path end in -// the body. This makes this option more rarely used in practice when -// defining REST APIs. The common usage of `*` is in custom methods -// which don't use the URL at all for transferring data. -// -// It is possible to define multiple HTTP methods for one RPC by using -// the `additional_bindings` option. Example: -// -// service Messaging { -// rpc GetMessage(GetMessageRequest) returns (Message) { -// option (google.api.http) = { -// get: "/v1/messages/{message_id}" -// additional_bindings { -// get: "/v1/users/{user_id}/messages/{message_id}" -// } -// }; -// } -// } -// message GetMessageRequest { -// string message_id = 1; -// string user_id = 2; -// } -// -// This enables the following two alternative HTTP JSON to RPC mappings: -// -// HTTP | gRPC -// -----|----- -// `GET /v1/messages/123456` | `GetMessage(message_id: "123456")` -// `GET /v1/users/me/messages/123456` | `GetMessage(user_id: "me" message_id: -// "123456")` -// -// ## Rules for HTTP mapping -// -// 1. Leaf request fields (recursive expansion nested messages in the request -// message) are classified into three categories: -// - Fields referred by the path template. They are passed via the URL path. -// - Fields referred by the [HttpRule.body][google.api.HttpRule.body]. They -// are passed via the HTTP -// request body. -// - All other fields are passed via the URL query parameters, and the -// parameter name is the field path in the request message. A repeated -// field can be represented as multiple query parameters under the same -// name. -// 2. If [HttpRule.body][google.api.HttpRule.body] is "*", there is no URL -// query parameter, all fields -// are passed via URL path and HTTP request body. -// 3. If [HttpRule.body][google.api.HttpRule.body] is omitted, there is no HTTP -// request body, all -// fields are passed via URL path and URL query parameters. -// -// ### Path template syntax -// -// Template = "/" Segments [ Verb ] ; -// Segments = Segment { "/" Segment } ; -// Segment = "*" | "**" | LITERAL | Variable ; -// Variable = "{" FieldPath [ "=" Segments ] "}" ; -// FieldPath = IDENT { "." IDENT } ; -// Verb = ":" LITERAL ; -// -// The syntax `*` matches a single URL path segment. The syntax `**` matches -// zero or more URL path segments, which must be the last part of the URL path -// except the `Verb`. -// -// The syntax `Variable` matches part of the URL path as specified by its -// template. A variable template must not contain other variables. If a variable -// matches a single path segment, its template may be omitted, e.g. `{var}` -// is equivalent to `{var=*}`. -// -// The syntax `LITERAL` matches literal text in the URL path. If the `LITERAL` -// contains any reserved character, such characters should be percent-encoded -// before the matching. -// -// If a variable contains exactly one path segment, such as `"{var}"` or -// `"{var=*}"`, when such a variable is expanded into a URL path on the client -// side, all characters except `[-_.~0-9a-zA-Z]` are percent-encoded. The -// server side does the reverse decoding. Such variables show up in the -// [Discovery -// Document](https://developers.google.com/discovery/v1/reference/apis) as -// `{var}`. -// -// If a variable contains multiple path segments, such as `"{var=foo/*}"` -// or `"{var=**}"`, when such a variable is expanded into a URL path on the -// client side, all characters except `[-_.~/0-9a-zA-Z]` are percent-encoded. -// The server side does the reverse decoding, except "%2F" and "%2f" are left -// unchanged. Such variables show up in the -// [Discovery -// Document](https://developers.google.com/discovery/v1/reference/apis) as -// `{+var}`. -// -// ## Using gRPC API Service Configuration -// -// gRPC API Service Configuration (service config) is a configuration language -// for configuring a gRPC service to become a user-facing product. The -// service config is simply the YAML representation of the `google.api.Service` -// proto message. -// -// As an alternative to annotating your proto file, you can configure gRPC -// transcoding in your service config YAML files. You do this by specifying a -// `HttpRule` that maps the gRPC method to a REST endpoint, achieving the same -// effect as the proto annotation. This can be particularly useful if you -// have a proto that is reused in multiple services. Note that any transcoding -// specified in the service config will override any matching transcoding -// configuration in the proto. -// -// Example: -// -// http: -// rules: -// # Selects a gRPC method and applies HttpRule to it. -// - selector: example.v1.Messaging.GetMessage -// get: /v1/messages/{message_id}/{sub.subfield} -// -// ## Special notes -// -// When gRPC Transcoding is used to map a gRPC to JSON REST endpoints, the -// proto to JSON conversion must follow the [proto3 -// specification](https://developers.google.com/protocol-buffers/docs/proto3#json). -// -// While the single segment variable follows the semantics of -// [RFC 6570](https://tools.ietf.org/html/rfc6570) Section 3.2.2 Simple String -// Expansion, the multi segment variable **does not** follow RFC 6570 Section -// 3.2.3 Reserved Expansion. The reason is that the Reserved Expansion -// does not expand special characters like `?` and `#`, which would lead -// to invalid URLs. As the result, gRPC Transcoding uses a custom encoding -// for multi segment variables. -// -// The path variables **must not** refer to any repeated or mapped field, -// because client libraries are not capable of handling such variable expansion. -// -// The path variables **must not** capture the leading "/" character. The reason -// is that the most common use case "{var}" does not capture the leading "/" -// character. For consistency, all path variables must share the same behavior. -// -// Repeated message fields must not be mapped to URL query parameters, because -// no client library can support such complicated mapping. -// -// If an API needs to use a JSON array for request or response body, it can map -// the request or response body to a repeated field. However, some gRPC -// Transcoding implementations may not support this feature. -message HttpRule { - // Selects a method to which this rule applies. - // - // Refer to [selector][google.api.DocumentationRule.selector] for syntax - // details. - string selector = 1; - - // Determines the URL pattern is matched by this rules. This pattern can be - // used with any of the {get|put|post|delete|patch} methods. A custom method - // can be defined using the 'custom' field. - oneof pattern { - // Maps to HTTP GET. Used for listing and getting information about - // resources. - string get = 2; - - // Maps to HTTP PUT. Used for replacing a resource. - string put = 3; - - // Maps to HTTP POST. Used for creating a resource or performing an action. - string post = 4; - - // Maps to HTTP DELETE. Used for deleting a resource. - string delete = 5; - - // Maps to HTTP PATCH. Used for updating a resource. - string patch = 6; - - // The custom pattern is used for specifying an HTTP method that is not - // included in the `pattern` field, such as HEAD, or "*" to leave the - // HTTP method unspecified for this rule. The wild-card rule is useful - // for services that provide content to Web (HTML) clients. - CustomHttpPattern custom = 8; - } - - // The name of the request field whose value is mapped to the HTTP request - // body, or `*` for mapping all request fields not captured by the path - // pattern to the HTTP body, or omitted for not having any HTTP request body. - // - // NOTE: the referred field must be present at the top-level of the request - // message type. - string body = 7; - - // Optional. The name of the response field whose value is mapped to the HTTP - // response body. When omitted, the entire response message will be used - // as the HTTP response body. - // - // NOTE: The referred field must be present at the top-level of the response - // message type. - string response_body = 12; - - // Additional HTTP bindings for the selector. Nested bindings must - // not contain an `additional_bindings` field themselves (that is, - // the nesting may only be one level deep). - repeated HttpRule additional_bindings = 11; -} - -// A custom pattern is used for defining custom HTTP verb. -message CustomHttpPattern { - // The name of this custom HTTP verb. - string kind = 1; - - // The path matched by this custom verb. - string path = 2; -} diff --git a/client/tests/tests.go b/client/tests/tests.go index a1cbb9a6..39b61249 100644 --- a/client/tests/tests.go +++ b/client/tests/tests.go @@ -13,6 +13,8 @@ import ( "github.com/go-orb/go-orb/log" "github.com/go-orb/go-orb/registry" "github.com/go-orb/go-orb/types" + "github.com/go-orb/go-orb/util/metadata" + "github.com/go-orb/go-orb/util/orberrors" "github.com/go-orb/plugins/client/tests/proto" "github.com/stretchr/testify/suite" "golang.org/x/exp/slices" @@ -166,13 +168,13 @@ func (s *TestSuite) SetupSuite() { s.clientName = types.ServiceName("org.orb.svc.client") // Logger - logger, err := log.ProvideLogger(s.clientName, cfgData) + logger, err := log.Provide(s.clientName, cfgData) s.Require().NoError(err, "while setting up logger") s.Require().NoError(logger.Start()) s.logger = logger // Registry - reg, err := registry.ProvideRegistry(s.clientName, version, cfgData, logger) + reg, err := registry.Provide(s.clientName, version, cfgData, logger) if err != nil { s.Require().NoError(err, "while creating a registry") } @@ -181,7 +183,7 @@ func (s *TestSuite) SetupSuite() { s.registry = reg // Client - c, err := client.ProvideClient(s.clientName, cfgData, logger, reg) + c, err := client.Provide(s.clientName, cfgData, logger, reg) if err != nil { s.Require().NoError(err, "while creating a client") } @@ -312,3 +314,40 @@ func (s *TestSuite) TestRunRequests() { }) } } + +// TestFailingAuthorization tests an authorization call that must fail. +func (s *TestSuite) TestFailingAuthorization() { + responseMd := make(map[string]string) + ctx := context.Background() + _, err := client.Call[proto.CallResponse]( + ctx, + s.client, + string(ServiceName), + "echo.Streams/AuthorizedCall", + &proto.CallRequest{Name: "empty"}, + client.Headers(responseMd), + ) + s.Require().ErrorIs(err, orberrors.ErrUnauthorized) +} + +// TestMetadata checks if metadata gets transported over the wire. +func (s *TestSuite) TestMetadata() { + ctx := context.Background() + ctx, md := metadata.WithOutgoing(ctx) + md["authorization"] = "bearer pleaseHackMe" + + responseMd := make(map[string]string) + _, err := client.Call[proto.CallResponse]( + ctx, + s.client, + string(ServiceName), + "echo.Streams/AuthorizedCall", + &proto.CallRequest{Name: "empty"}, + client.Headers(responseMd), + ) + s.Require().NoError(err) + + rspHandler, ok := responseMd["tracing-id"] + s.Require().True(ok, "Transport does not transport metadata - tracing-id") + s.Require().Equal("asfdjhladhsfashf", rspHandler) +} diff --git a/client/tests/orb_transport/drpc/drpc_test.go b/client/tests/transport/drpc/drpc_test.go similarity index 100% rename from client/tests/orb_transport/drpc/drpc_test.go rename to client/tests/transport/drpc/drpc_test.go diff --git a/client/tests/orb_transport/grpc/grpc_test.go b/client/tests/transport/grpc/grpc_test.go similarity index 100% rename from client/tests/orb_transport/grpc/grpc_test.go rename to client/tests/transport/grpc/grpc_test.go diff --git a/client/tests/orb_transport/h2c/h2c_test.go b/client/tests/transport/h2c/h2c_test.go similarity index 100% rename from client/tests/orb_transport/h2c/h2c_test.go rename to client/tests/transport/h2c/h2c_test.go diff --git a/client/tests/orb_transport/hertzh2c/hertzh2c_test.go b/client/tests/transport/hertzh2c/hertzh2c_test.go similarity index 100% rename from client/tests/orb_transport/hertzh2c/hertzh2c_test.go rename to client/tests/transport/hertzh2c/hertzh2c_test.go diff --git a/client/tests/orb_transport/hertzhttp/hertzhttp_test.go b/client/tests/transport/hertzhttp/hertzhttp_test.go similarity index 100% rename from client/tests/orb_transport/hertzhttp/hertzhttp_test.go rename to client/tests/transport/hertzhttp/hertzhttp_test.go diff --git a/client/tests/orb_transport/http/http_test.go b/client/tests/transport/http/http_test.go similarity index 100% rename from client/tests/orb_transport/http/http_test.go rename to client/tests/transport/http/http_test.go diff --git a/client/tests/orb_transport/http3/http3_test.go b/client/tests/transport/http3/http3_test.go similarity index 100% rename from client/tests/orb_transport/http3/http3_test.go rename to client/tests/transport/http3/http3_test.go diff --git a/client/tests/orb_transport/https/https_test.go b/client/tests/transport/https/https_test.go similarity index 100% rename from client/tests/orb_transport/https/https_test.go rename to client/tests/transport/https/https_test.go diff --git a/event/natsjs/natsjs.go b/event/natsjs/natsjs.go index 0c2313d1..239be3c5 100644 --- a/event/natsjs/natsjs.go +++ b/event/natsjs/natsjs.go @@ -21,7 +21,7 @@ import ( var _ event.Events = (*NatsJS)(nil) type replyMessage struct { - Metadata metadata.Metadata `json:"metadata"` + Metadata map[string]string `json:"metadata"` Data []byte `json:"data"` Err error `json:"err"` } @@ -139,9 +139,9 @@ func (n *NatsJS) HandleRequest( } req.SetReplyFunc(func(ctx context.Context, result []byte, inErr error) { - md, ok := metadata.From(ctx) + md, ok := metadata.OutgoingFrom(ctx) if !ok { - md = make(metadata.Metadata) + md = make(map[string]string) } reply := &replyMessage{ diff --git a/server/cmd/protoc-gen-go-orb/orb/template.go b/server/cmd/protoc-gen-go-orb/orb/template.go index b80329bb..1725bb96 100644 --- a/server/cmd/protoc-gen-go-orb/orb/template.go +++ b/server/cmd/protoc-gen-go-orb/orb/template.go @@ -23,18 +23,20 @@ import ( {{- range .Services }} type {{.Type}}Handler interface { - {{.Type}}(ctx context.Context, req *Req) (*Resp, error) + {{- range .Methods }} + {{.Name}}(ctx context.Context, req *{{.Request}}) (*{{.Reply}}, error) + {{ end -}} } {{- if $.ServerDRPC }} func register{{.Type}}DRPCHandler(srv *mdrpc.Server, handler {{.Type}}Handler) error { - desc := DRPCEchoDescription{} + desc := DRPC{{.Type}}Description{} // Register with DRPC. r := srv.Router() - // Register with the drpcmux. + // Register with the server/drpc(.Mux). err := r.Register(handler, desc) if err != nil { return err @@ -56,8 +58,8 @@ func register{{.Type}}DRPCHandler(srv *mdrpc.Server, handler {{.Type}}Handler) e // register{{.Type}}HTTPHandler registers the service to an HTTP server. func register{{.Type}}HTTPHandler(srv *mhttp.ServerHTTP, handler {{.Type}}Handler) { r := srv.Router() - {{- range .Methods}} - r.{{.Method}}("{{.Path}}", mhttp.NewGRPCHandler(srv, handler.{{.Name}})) + {{$method := .}}{{- range .Methods}} + r.{{.Method}}("{{.Path}}", mhttp.NewGRPCHandler(srv, handler.{{.Name}}, "{{$method.Name}}", "{{.Name}}")) {{- end}} } @@ -68,16 +70,16 @@ func register{{.Type}}HTTPHandler(srv *mhttp.ServerHTTP, handler {{.Type}}Handle // register{{.Type}}HertzHandler registers the service to an Hertz server. func register{{.Type}}HertzHandler(srv *mhertz.Server, handler {{.Type}}Handler) { r := srv.Router() - {{- range .Methods}} - r.{{.MethodUpper}}("{{.Path}}", mhertz.NewGRPCHandler(srv, handler.{{.Name}})) + {{$method := .}}{{- range .Methods}} + r.{{.MethodUpper}}("{{.Path}}", mhertz.NewGRPCHandler(srv, handler.{{.Name}}, "{{$method.Name}}", "{{.Name}}")) {{- end}} } {{ end -}} -// Register{{.Type}}Service will return a registration function that can be +// Register{{.Type}}Handler will return a registration function that can be // provided to entrypoints as a handler registration. -func Register{{.Type}}Service(handler {{.Type}}Handler) server.RegistrationFunc { +func Register{{.Type}}Handler(handler {{.Type}}Handler) server.RegistrationFunc { return server.RegistrationFunc(func(s any) { switch srv := s.(type) { {{ if $.ServerGRPC }} @@ -116,7 +118,7 @@ type protoFile struct { } func (p protoFile) Render() string { - tmpl, err := template.New("http").Parse(strings.TrimSpace(orbTemplate)) + tmpl, err := template.New("orb").Parse(strings.TrimSpace(orbTemplate)) if err != nil { panic(err) } diff --git a/server/cmd/protoc-gen-go-orb/orbdrpc/drpc.go b/server/cmd/protoc-gen-go-orb/orbdrpc/drpc.go index f532deba..e14f3ea4 100644 --- a/server/cmd/protoc-gen-go-orb/orbdrpc/drpc.go +++ b/server/cmd/protoc-gen-go-orb/orbdrpc/drpc.go @@ -265,11 +265,6 @@ func (d *drpc) generateService(service *protogen.Service) { d.P("}") d.P() - // Registration helper - d.P("func DRPCRegister", service.GoName, "(mux ", d.Ident("storj.io/drpc", "Mux"), ", impl ", d.ServerIface(service), ") error {") - d.P("return mux.Register(impl, ", d.ServerDesc(service), "{})") - d.P("}") - // Server methods for _, method := range service.Methods { d.generateServerMethod(method) diff --git a/server/drpc/config.go b/server/drpc/config.go index f3ba7a1f..d1608191 100644 --- a/server/drpc/config.go +++ b/server/drpc/config.go @@ -5,8 +5,7 @@ import ( "net" "time" - "log/slog" - + "github.com/go-orb/go-orb/log" "github.com/go-orb/go-orb/server" "github.com/google/uuid" ) @@ -75,12 +74,12 @@ type Config struct { // Handlers global, and setting them explicitly in the config. HandlerRegistrations server.HandlerRegistrations `json:"handlers" yaml:"handlers"` + // Middlewares is a list of middleware to use. + Middlewares []string `json:"middlewares" yaml:"middlewares"` + // Logger allows you to dynamically change the log level and plugin for a // specific entrypoint. - Logger struct { - Level slog.Level `json:"level,omitempty" yaml:"level,omitempty"` // TODO(davincible): change with custom level - Plugin string `json:"plugin,omitempty" yaml:"plugin,omitempty"` - } `json:"logger" yaml:"logger"` + Logger log.Config `json:"logger" yaml:"logger"` } // NewConfig will create a new default config for the entrypoint. @@ -89,6 +88,7 @@ func NewConfig(options ...Option) *Config { Name: "dprc-" + uuid.NewString(), Address: DefaultAddress, HandlerRegistrations: make(server.HandlerRegistrations), + Middlewares: []string{}, } cfg.ApplyOptions(options...) @@ -163,7 +163,7 @@ func WithRegistration(name string, registration server.RegistrationFunc) Option } // WithLogLevel changes the log level from the inherited logger. -func WithLogLevel(level slog.Level) Option { +func WithLogLevel(level string) Option { return func(c *Config) { c.Logger.Level = level } @@ -211,3 +211,14 @@ func WithEntrypoint(options ...Option) server.Option { } } } + +// WithMiddleware appends middlewares to the server. +// You can use any standard Go HTTP middleware. +// +// Each middlware is uniquely identified with a name. The name provided here +// can be used to dynamically add middlware to an entrypoint in a config. +func WithMiddleware(middlewares ...string) Option { + return func(c *Config) { + c.Middlewares = append(c.Middlewares, middlewares...) + } +} diff --git a/server/drpc/drpc.go b/server/drpc/drpc.go index 43f3ffab..d0fa07bb 100644 --- a/server/drpc/drpc.go +++ b/server/drpc/drpc.go @@ -7,7 +7,6 @@ import ( "log/slog" "net" - "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" "github.com/go-orb/go-orb/log" @@ -16,6 +15,8 @@ import ( "github.com/go-orb/go-orb/types" "github.com/go-orb/go-orb/util/addr" "github.com/google/uuid" + + "github.com/gammazero/workerpool" ) var _ orbserver.Entrypoint = (*Server)(nil) @@ -32,15 +33,18 @@ type Server struct { ctx context.Context cancelFunc context.CancelFunc - mux *drpcmux.Mux + mux *Mux server *drpcserver.Server + middlewares []orbserver.Middleware + endpoints []string // entrypointID is the entrypointID (uuid) of this entrypoint in the registry. entrypointID string started bool + wp *workerpool.WorkerPool } // Start will create the listeners and start the server on the entrypoint. @@ -52,7 +56,7 @@ func (s *Server) Start() error { s.logger.Info("Starting", "address", s.config.Address) // create a drpc RPC mux - s.mux = drpcmux.New() + s.mux = newMux(s) // Register handlers. for _, h := range s.config.HandlerRegistrations { @@ -71,16 +75,35 @@ func (s *Server) Start() error { } } - go func(s *Server, listener net.Listener) { - err := s.server.Serve(s.ctx, listener) - s.logger.Error("While serving", "error", err) - }(s, listener) + s.ctx, s.cancelFunc = context.WithCancel(context.Background()) + + s.wp = workerpool.New(256) + go s.run(s.ctx, listener) s.started = true return s.registryRegister() } +func (s *Server) run(ctx context.Context, lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + + continue + } + + s.wp.Submit(func() { + if err = s.server.ServeOne(ctx, conn); err != nil { + s.logger.Error("while serving", "err", err, "addr", conn.RemoteAddr()) + } + }) + } +} + // Stop will stop the Hertz server(s). func (s *Server) Stop(_ context.Context) error { if !s.started { @@ -89,6 +112,7 @@ func (s *Server) Stop(_ context.Context) error { // Stops the dRPC Server. s.cancelFunc() + s.wp.Stop() return s.registryDeregister() } @@ -149,7 +173,7 @@ func (s *Server) Type() string { } // Router returns the drpc mux. -func (s *Server) Router() *drpcmux.Mux { +func (s *Server) Router() *Mux { return s.mux } @@ -196,9 +220,9 @@ func (s *Server) registryDeregister() error { return s.registry.Deregister(rService) } -// ProvideServer creates a new entrypoint for a single address. You can create +// Provide creates a new entrypoint for a single address. You can create // multiple entrypoints for multiple addresses and ports. -func ProvideServer( +func Provide( _ types.ServiceName, logger log.Logger, reg registry.Type, @@ -222,12 +246,23 @@ func ProvideServer( ctx, cancelFunc := context.WithCancel(context.Background()) + mws := []orbserver.Middleware{} + + for _, m := range cfg.Middlewares { + if mw, ok := orbserver.Middlewares.Get(m); ok { + mws = append(mws, mw) + } else { + logger.Error("unknown middleware given", "middleware", m) + } + } + entrypoint := Server{ - config: cfg, - logger: logger, - registry: reg, - ctx: ctx, - cancelFunc: cancelFunc, + config: cfg, + logger: logger, + registry: reg, + middlewares: mws, + ctx: ctx, + cancelFunc: cancelFunc, } return &entrypoint, nil diff --git a/server/drpc/go.mod b/server/drpc/go.mod index 18927599..9190ac5a 100644 --- a/server/drpc/go.mod +++ b/server/drpc/go.mod @@ -5,6 +5,7 @@ go 1.23 toolchain go1.23.0 require ( + github.com/gammazero/workerpool v1.1.3 github.com/go-orb/go-orb v0.0.0-20240831182006-95fb90a9afe7 github.com/google/uuid v1.6.0 storj.io/drpc v0.0.34 @@ -12,6 +13,7 @@ require ( require ( github.com/cornelk/hashmap v1.0.8 // indirect + github.com/gammazero/deque v0.2.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/kr/text v0.2.0 // indirect diff --git a/server/drpc/go.sum b/server/drpc/go.sum index 5cc2aa10..605a7f4e 100644 --- a/server/drpc/go.sum +++ b/server/drpc/go.sum @@ -3,6 +3,10 @@ github.com/cornelk/hashmap v1.0.8/go.mod h1:RfZb7JO3RviW/rT6emczVuC/oxpdz4UsSB2L github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gammazero/deque v0.2.0 h1:SkieyNB4bg2/uZZLxvya0Pq6diUlwx7m2TeT7GAIWaA= +github.com/gammazero/deque v0.2.0/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= +github.com/gammazero/workerpool v1.1.3 h1:WixN4xzukFoN0XSeXF6puqEqFTl2mECI9S6W44HWy9Q= +github.com/gammazero/workerpool v1.1.3/go.mod h1:wPjyBLDbyKnUn2XwwyD3EEwo9dHutia9/fwNmSHWACc= github.com/go-orb/go-orb v0.0.0-20240831182006-95fb90a9afe7 h1:9ZCjLkvUlwDSwIhkVRij0nLuP38BMJAJCv4cX1TP3Mg= github.com/go-orb/go-orb v0.0.0-20240831182006-95fb90a9afe7/go.mod h1:DdL+EYRtGU8OMU4H7NSQvRxRVzL+GBliOVWC0QHKuy0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -26,6 +30,8 @@ github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/errs v1.3.0 h1:hmiaKqgYZzcVgRL1Vkc1Mn2914BbzB0IBxs+ebeutGs= github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/server/drpc/handle_rpc.go b/server/drpc/handle_rpc.go new file mode 100644 index 00000000..4b8710e3 --- /dev/null +++ b/server/drpc/handle_rpc.go @@ -0,0 +1,112 @@ +// Copyright (C) 2020 Storj Labs, Inc. +// Copyright (C) 2024 go-orb Authors. +// See LICENSE for copying information. + +package drpc + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + + "github.com/go-orb/go-orb/util/metadata" + "github.com/go-orb/go-orb/util/orberrors" + "github.com/go-orb/plugins/server/drpc/message" + "github.com/zeebo/errs" + proto "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + + "storj.io/drpc" + "storj.io/drpc/drpcerr" + "storj.io/drpc/drpcmetadata" +) + +type streamWrapper struct { + drpc.Stream + ctx context.Context +} + +func (s *streamWrapper) Context() context.Context { return s.ctx } + +// HandleRPC handles the rpc that has been requested by the stream. +func (m *Mux) HandleRPC(stream drpc.Stream, rpc string) (err error) { + data, ok := m.rpcs[rpc] + if !ok { + return drpc.ProtocolError.New("unknown rpc: %q", rpc) + } + + req := interface{}(stream) + + if data.in1 != streamType { + msg, ok := reflect.New(data.in1.Elem()).Interface().(drpc.Message) + if !ok { + return drpc.InternalError.New("invalid rpc input type") + } + + if err := stream.MsgRecv(msg, data.enc); err != nil { + return errs.Wrap(err) + } + + req = msg + } + + ctx := stream.Context() + + ctx = metadata.EnsureIncoming(ctx) + ctx = metadata.EnsureOutgoing(ctx) + + dMeta, ok := drpcmetadata.Get(ctx) + if !ok { + dMeta = make(map[string]string) + } + + incMd, _ := metadata.IncomingFrom(ctx) + for k, v := range dMeta { + incMd[k] = v + } + + fmSplit := strings.Split(rpc, "/") + + if len(fmSplit) >= 3 { + incMd[metadata.Service] = fmSplit[1] + incMd[metadata.Method] = fmSplit[2] + } + + stream = &streamWrapper{Stream: stream, ctx: ctx} + + // Apply middleware. + h := func(ctx context.Context, req any) (any, error) { + // The actual call. + return data.receiver(data.srv, ctx, req, stream) + } + for _, m := range m.orbSrv.middlewares { + h = m.Call(h) + } + + // Calls all middlewares until the actual call. + out, err := h(ctx, req) + + switch { + case err != nil: + oErr := orberrors.From(err) + + if oErr.Wrapped != nil { + return drpcerr.WithCode(fmt.Errorf("%s: %s", oErr.Message, oErr.Wrapped.Error()), uint64(oErr.Code)) + } + + return drpcerr.WithCode(errors.New(oErr.Message), uint64(oErr.Code)) + case out != nil && !reflect.ValueOf(out).IsNil(): + outMd, _ := metadata.OutgoingFrom(ctx) + + outData, err := anypb.New(out.(proto.Message)) + if err != nil { + return errs.Wrap(err) + } + + return stream.MsgSend(&message.Response{Metadata: outMd, Data: outData}, data.enc) + default: + return stream.CloseSend() + } +} diff --git a/server/drpc/message/message.pb.go b/server/drpc/message/message.pb.go new file mode 100644 index 00000000..1cde34f9 --- /dev/null +++ b/server/drpc/message/message.pb.go @@ -0,0 +1,165 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.34.2 +// protoc v5.27.3 +// source: message.proto + +package message + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + anypb "google.golang.org/protobuf/types/known/anypb" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Response struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Metadata map[string]string `protobuf:"bytes,1,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + Data *anypb.Any `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *Response) Reset() { + *x = Response{} + if protoimpl.UnsafeEnabled { + mi := &file_message_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Response) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Response) ProtoMessage() {} + +func (x *Response) ProtoReflect() protoreflect.Message { + mi := &file_message_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Response.ProtoReflect.Descriptor instead. +func (*Response) Descriptor() ([]byte, []int) { + return file_message_proto_rawDescGZIP(), []int{0} +} + +func (x *Response) GetMetadata() map[string]string { + if x != nil { + return x.Metadata + } + return nil +} + +func (x *Response) GetData() *anypb.Any { + if x != nil { + return x.Data + } + return nil +} + +var File_message_proto protoreflect.FileDescriptor + +var file_message_proto_rawDesc = []byte{ + 0x0a, 0x0d, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x19, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x61, 0x6e, 0x79, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x22, 0xae, 0x01, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x3b, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, + 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x28, 0x0a, + 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, + 0x79, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, + 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x3a, 0x02, 0x38, 0x01, 0x42, 0x0b, 0x5a, 0x09, 0x2e, 0x3b, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_message_proto_rawDescOnce sync.Once + file_message_proto_rawDescData = file_message_proto_rawDesc +) + +func file_message_proto_rawDescGZIP() []byte { + file_message_proto_rawDescOnce.Do(func() { + file_message_proto_rawDescData = protoimpl.X.CompressGZIP(file_message_proto_rawDescData) + }) + return file_message_proto_rawDescData +} + +var file_message_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_message_proto_goTypes = []any{ + (*Response)(nil), // 0: message.Response + nil, // 1: message.Response.MetadataEntry + (*anypb.Any)(nil), // 2: google.protobuf.Any +} +var file_message_proto_depIdxs = []int32{ + 1, // 0: message.Response.metadata:type_name -> message.Response.MetadataEntry + 2, // 1: message.Response.data:type_name -> google.protobuf.Any + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_message_proto_init() } +func file_message_proto_init() { + if File_message_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_message_proto_msgTypes[0].Exporter = func(v any, i int) any { + switch v := v.(*Response); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_message_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_message_proto_goTypes, + DependencyIndexes: file_message_proto_depIdxs, + MessageInfos: file_message_proto_msgTypes, + }.Build() + File_message_proto = out.File + file_message_proto_rawDesc = nil + file_message_proto_goTypes = nil + file_message_proto_depIdxs = nil +} diff --git a/server/drpc/message/message.proto b/server/drpc/message/message.proto new file mode 100644 index 00000000..a8f27cc5 --- /dev/null +++ b/server/drpc/message/message.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package message; + +import "google/protobuf/any.proto"; + +option go_package = ".;message"; + +message Response { + map metadata = 1; + google.protobuf.Any data = 2; +} \ No newline at end of file diff --git a/server/drpc/mux.go b/server/drpc/mux.go new file mode 100644 index 00000000..c3f59471 --- /dev/null +++ b/server/drpc/mux.go @@ -0,0 +1,98 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// Package drpc provides a drpc mux that handles orb middleware. +package drpc + +import ( + "reflect" + + "github.com/zeebo/errs" + + "storj.io/drpc" +) + +// Mux is an implementation of Handler to serve drpc connections to the +// appropriate Receivers registered by Descriptions. +type Mux struct { + orbSrv *Server + + rpcs map[string]rpcData +} + +func newMux(srv *Server) *Mux { + return &Mux{ + orbSrv: srv, + rpcs: make(map[string]rpcData), + } +} + +//nolint:gochecknoglobals +var ( + streamType = reflect.TypeOf((*drpc.Stream)(nil)).Elem() + messageType = reflect.TypeOf((*drpc.Message)(nil)).Elem() +) + +type rpcData struct { + srv interface{} + enc drpc.Encoding + receiver drpc.Receiver + in1 reflect.Type + in2 reflect.Type + unitary bool +} + +// Register associates the RPCs described by the description in the server. +// It returns an error if there was a problem registering it. +func (m *Mux) Register(srv interface{}, desc drpc.Description) error { + n := desc.NumMethods() + for i := 0; i < n; i++ { + rpc, enc, receiver, method, ok := desc.Method(i) + if !ok { + return errs.New("Description returned invalid method for index %d", i) + } + + if err := m.registerOne(srv, rpc, enc, receiver, method); err != nil { + return err + } + } + + return nil +} + +// registerOne does the work to register a single rpc. +func (m *Mux) registerOne(srv interface{}, rpc string, enc drpc.Encoding, receiver drpc.Receiver, method interface{}) error { + data := rpcData{srv: srv, enc: enc, receiver: receiver} + + switch mt := reflect.TypeOf(method); { + // unitary input, unitary output + case mt.NumOut() == 2: + data.unitary = true + data.in1 = mt.In(2) + + if !data.in1.Implements(messageType) { + return errs.New("input argument not a drpc message: %v", data.in1) + } + + // unitary input, stream output + case mt.NumIn() == 3: + data.in1 = mt.In(1) + if !data.in1.Implements(messageType) { + return errs.New("input argument not a drpc message: %v", data.in1) + } + + data.in2 = streamType + + // stream input + case mt.NumIn() == 2: + data.in1 = streamType + + // code gen bug? + default: + return errs.New("unknown method type: %v", mt) + } + + m.rpcs[rpc] = data + + return nil +} diff --git a/server/drpc/plugin.go b/server/drpc/plugin.go index ad49420e..c1df7612 100644 --- a/server/drpc/plugin.go +++ b/server/drpc/plugin.go @@ -23,7 +23,7 @@ func pluginProvider( return nil, ErrInvalidConfigType } - return ProvideServer(service, logger, reg, *cfg) + return Provide(service, logger, reg, *cfg) } func newDefaultConfig() server.EntrypointConfig { diff --git a/server/grpc/error.go b/server/grpc/error.go new file mode 100644 index 00000000..0c53b9fa --- /dev/null +++ b/server/grpc/error.go @@ -0,0 +1,40 @@ +// Package grpc is the grpc transport for plugins/client/orb. +package grpc + +import ( + "net/http" + + "google.golang.org/grpc/codes" +) + +// codeToHTTPStatus maps gRPC codes to HTTP statuses. +// Based on https://cloud.google.com/apis/design/errors +// +// Copied from: https://github.com/luci/luci-go/blob/main/grpc/grpcutil/errors.go#L118 (Apache 2.0). +// +//nolint:gochecknoglobals +var httpStatusToCode = map[int]codes.Code{ + http.StatusOK: codes.OK, + 499: codes.Canceled, + http.StatusBadRequest: codes.InvalidArgument, + http.StatusInternalServerError: codes.Internal, + http.StatusGatewayTimeout: codes.DeadlineExceeded, + http.StatusNotFound: codes.NotFound, + http.StatusConflict: codes.AlreadyExists, + http.StatusForbidden: codes.PermissionDenied, + http.StatusUnauthorized: codes.Unauthenticated, + http.StatusTooManyRequests: codes.ResourceExhausted, + http.StatusNotImplemented: codes.Unimplemented, + http.StatusServiceUnavailable: codes.Unavailable, +} + +// HTTPStatusToCode maps HTTP status codes to gRPC codes. +// +// Falls back to codes.Internal if the code is unrecognized. +func HTTPStatusToCode(code int) codes.Code { + if status, ok := httpStatusToCode[code]; ok { + return status + } + + return codes.Internal +} diff --git a/server/grpc/grpc.go b/server/grpc/grpc.go index b798ae58..96094a31 100644 --- a/server/grpc/grpc.go +++ b/server/grpc/grpc.go @@ -51,8 +51,8 @@ type ServerGRPC struct { started bool } -// ProvideServerGRPC creates a gRPC server by options. -func ProvideServerGRPC( +// Provide creates a gRPC server by options. +func Provide( _ types.ServiceName, logger log.Logger, reg registry.Type, diff --git a/server/grpc/interceptor.go b/server/grpc/interceptor.go index 467335ae..504a954b 100644 --- a/server/grpc/interceptor.go +++ b/server/grpc/interceptor.go @@ -2,23 +2,50 @@ package grpc import ( "context" + "slices" + "strings" + "github.com/go-orb/go-orb/util/metadata" + "github.com/go-orb/go-orb/util/orberrors" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + gmetadata "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) +//nolint:gochecknoglobals +var stdHeaders = []string{"content-type", "user-agent"} + func (s *ServerGRPC) unaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + ctx = metadata.EnsureIncoming(ctx) + ctx = metadata.EnsureOutgoing(ctx) + + reqMd, _ := metadata.IncomingFrom(ctx) + + // Copy incoming metadata from grpc to orb. + if gReqMd, ok := gmetadata.FromIncomingContext(ctx); ok { + for k, v := range gReqMd { + if slices.Contains(stdHeaders, k) { + continue + } + + reqMd[k] = v[0] + } + } + + fmSplit := strings.Split(info.FullMethod, "/") + if len(fmSplit) >= 3 { + reqMd[metadata.Service] = fmSplit[1] + reqMd[metadata.Method] = fmSplit[2] + } + var cancel func() if s.config.Timeout > 0 { ctx, cancel = context.WithTimeout(ctx, s.config.Timeout) defer cancel() } - // Directly execute handler if no middleware is defined. - if s.unaryMiddleware == 0 { - return handler(ctx, req) - } - h := func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { return handler(ctx, req) } @@ -29,7 +56,33 @@ func (s *ServerGRPC) unaryServerInterceptor() grpc.UnaryServerInterceptor { h = chainUnaryInterceptors(next) } - return h(ctx, req, info, handler) + result, err := h(ctx, req, info, handler) + + outMd, _ := metadata.OutgoingFrom(ctx) + if len(outMd) > 0 { + gOutMd := make(gmetadata.MD) + + for k, v := range outMd { + gOutMd[k] = []string{v} + } + + if err := grpc.SendHeader(ctx, gOutMd); err != nil { + return nil, status.Errorf(codes.Internal, "internal error while sending headers") + } + } + + if err != nil { + oErr := orberrors.From(err) + gCode := HTTPStatusToCode(oErr.Code) + + if oErr.Wrapped != nil { + return nil, status.Errorf(gCode, "%s: %s", oErr.Message, oErr.Wrapped.Error()) + } + + return nil, status.Errorf(gCode, "%s", oErr.Message) + } + + return result, nil } } diff --git a/server/grpc/plugin.go b/server/grpc/plugin.go index 52b11b81..93e81752 100644 --- a/server/grpc/plugin.go +++ b/server/grpc/plugin.go @@ -44,7 +44,7 @@ func pluginProvider( return nil, ErrInvalidConfigType } - return ProvideServerGRPC(service, logger, reg, *cfg) + return Provide(service, logger, reg, *cfg) } func newDefaultConfig() server.EntrypointConfig { diff --git a/server/grpc/tests/grpc_test.go b/server/grpc/tests/grpc_test.go index 436a61d9..41386c61 100644 --- a/server/grpc/tests/grpc_test.go +++ b/server/grpc/tests/grpc_test.go @@ -136,13 +136,13 @@ func TestGrpcIntegration(t *testing.T) { name := types.ServiceName("com.example.test") version := types.ServiceVersion("v1.0.0") - logger, err := log.ProvideLogger(name, nil) + logger, err := log.Provide(name, nil) require.NoError(t, err, "failed to setup logger") - reg, err := registry.ProvideRegistry(name, version, nil, logger) + reg, err := registry.Provide(name, version, nil, logger) require.NoError(t, err, "failed to setup the registry") - srv, err := server.ProvideServer(name, nil, logger, reg, + srv, err := server.Provide(name, nil, logger, reg, mgrpc.WithDefaults( mgrpc.WithGRPCReflection(false), mgrpc.WithInsecure(true), @@ -222,13 +222,13 @@ func TestServerFileConfig(t *testing.T) { config, err := config.Read([]*url.URL{fURL}, nil) require.NoError(t, err, "failed to read file config") - logger, err := log.ProvideLogger(name, nil) + logger, err := log.Provide(name, nil) require.NoError(t, err, "failed to setup logger") - reg, err := registry.ProvideRegistry(name, version, nil, logger) + reg, err := registry.Provide(name, version, nil, logger) require.NoError(t, err, "failed to setup the registry") - srv, err := server.ProvideServer(name, config, logger, reg, + srv, err := server.Provide(name, config, logger, reg, mgrpc.WithEntrypoint( mgrpc.WithName("static-ep-1"), mgrpc.WithAddress(":48081"), diff --git a/server/grpc/tests/util/grpc/grpc.go b/server/grpc/tests/util/grpc/grpc.go index 383b41cd..731cfd3b 100644 --- a/server/grpc/tests/util/grpc/grpc.go +++ b/server/grpc/tests/util/grpc/grpc.go @@ -30,12 +30,12 @@ func SetupServer(opts ...mgrpc.Option) (*mgrpc.ServerGRPC, func(t *testing.T), e return nil, nil, fmt.Errorf("setup logger: %w", err) } - reg, err := registry.ProvideRegistry("app", "v1.0.0", nil, logger) + reg, err := registry.Provide("app", "v1.0.0", nil, logger) if err != nil { return nil, nil, fmt.Errorf("setup registry: %w", err) } - srv, err := mgrpc.ProvideServerGRPC("", logger, reg, *cfg, opts...) + srv, err := mgrpc.Provide("", logger, reg, *cfg, opts...) if err != nil { return nil, nil, fmt.Errorf("setup gRPC server: %w", err) } diff --git a/server/hertz/config.go b/server/hertz/config.go index 145528f2..762d7631 100644 --- a/server/hertz/config.go +++ b/server/hertz/config.go @@ -6,9 +6,8 @@ import ( "fmt" "time" - "log/slog" - "github.com/go-orb/go-orb/codecs" + "github.com/go-orb/go-orb/log" "github.com/go-orb/go-orb/server" "github.com/go-orb/go-orb/util/slicemap" mtls "github.com/go-orb/go-orb/util/tls" @@ -174,12 +173,12 @@ type Config struct { // Handlers global, and setting them explicitly in the config. HandlerRegistrations server.HandlerRegistrations `json:"handlers" yaml:"handlers"` + // Middlewares is a list of middleware to use. + Middlewares []string `json:"middlewares" yaml:"middlewares"` + // Logger allows you to dynamically change the log level and plugin for a // specific entrypoint. - Logger struct { - Level slog.Level `json:"level,omitempty" yaml:"level,omitempty"` // TODO(davincible): change with custom level - Plugin string `json:"plugin,omitempty" yaml:"plugin,omitempty"` - } `json:"logger" yaml:"logger"` + Logger log.Config `json:"logger" yaml:"logger"` } // NewConfig will create a new default config for the entrypoint. @@ -198,6 +197,7 @@ func NewConfig(options ...Option) *Config { IdleTimeout: DefaultIdleTimeout, StopTimeout: DefaultStopTimeout, HandlerRegistrations: make(server.HandlerRegistrations), + Middlewares: []string{}, } cfg.ApplyOptions(options...) @@ -370,7 +370,7 @@ func WithRegistration(name string, registration server.RegistrationFunc) Option } // WithLogLevel changes the log level from the inherited logger. -func WithLogLevel(level slog.Level) Option { +func WithLogLevel(level string) Option { return func(c *Config) { c.Logger.Level = level } @@ -418,3 +418,14 @@ func WithEntrypoint(options ...Option) server.Option { } } } + +// WithMiddleware appends middlewares to the server. +// You can use any standard Go HTTP middleware. +// +// Each middlware is uniquely identified with a name. The name provided here +// can be used to dynamically add middlware to an entrypoint in a config. +func WithMiddleware(middlewares ...string) Option { + return func(c *Config) { + c.Middlewares = append(c.Middlewares, middlewares...) + } +} diff --git a/server/hertz/handler.go b/server/hertz/handler.go index 8d5f9d4c..7ebb0498 100644 --- a/server/hertz/handler.go +++ b/server/hertz/handler.go @@ -3,15 +3,14 @@ package hertz import ( "context" "errors" - "strings" + "slices" "github.com/cloudwego/hertz/pkg/app" "github.com/go-orb/go-orb/util/metadata" "github.com/go-orb/go-orb/util/orberrors" ) -// orbHeader is the prefix for every orb HTTP header. -const orbHeader = "__orb-" +var stdHeaders = []string{"Accept", "Accept-Encoding", "Content-Length", "Content-Type", "User-Agent"} //nolint:gochecknoglobals // Errors. var ( @@ -21,12 +20,14 @@ var ( // NewGRPCHandler wraps a gRPC function with a Hertz handler. func NewGRPCHandler[Tin any, Tout any]( srv *Server, - f func(context.Context, *Tin) (*Tout, error), + fHandler func(context.Context, *Tin) (*Tout, error), + service string, + method string, ) func(c context.Context, ctx *app.RequestContext) { return func(ctx context.Context, apCtx *app.RequestContext) { - in := new(Tin) + request := new(Tin) - if _, err := srv.decodeBody(apCtx, in); err != nil { + if _, err := srv.decodeBody(apCtx, request); err != nil { srv.Logger.Error("failed to decode body", "error", err) WriteError(apCtx, err) @@ -34,19 +35,31 @@ func NewGRPCHandler[Tin any, Tout any]( } // Copy metadata from req Headers into the req.Context. - reqMd := make(metadata.Metadata) + ctx = metadata.EnsureIncoming(ctx) + ctx = metadata.EnsureOutgoing(ctx) + reqMd, _ := metadata.IncomingFrom(ctx) apCtx.VisitAllHeaders(func(k, v []byte) { sk := string(k) - if !strings.HasPrefix(strings.ToLower(sk), orbHeader) { + if slices.Contains(stdHeaders, sk) { return } - sk = sk[len(orbHeader):] reqMd[sk] = string(v) }) - out, err := f(reqMd.To(ctx), in) + reqMd[metadata.Service] = service + reqMd[metadata.Method] = method + + // Apply middleware. + h := func(ctx context.Context, req any) (any, error) { + return fHandler(ctx, req.(*Tin)) + } + for _, m := range srv.middlewares { + h = m.Call(h) + } + + out, err := h(ctx, request) if err != nil { srv.Logger.Error("RPC request failed", "error", err) WriteError(apCtx, err) @@ -54,9 +67,11 @@ func NewGRPCHandler[Tin any, Tout any]( return } - // Write back metadata to headers. - for k, v := range reqMd { - apCtx.Header(orbHeader+k, v) + // Write outgoing metadata. + if md, ok := metadata.OutgoingFrom(ctx); ok { + for k, v := range md { + apCtx.Header(k, v) + } } if err := srv.encodeBody(apCtx, out); err != nil { diff --git a/server/hertz/hertz.go b/server/hertz/hertz.go index fb39168b..377e9850 100644 --- a/server/hertz/hertz.go +++ b/server/hertz/hertz.go @@ -35,6 +35,8 @@ type Server struct { Logger log.Logger Registry registry.Type + middlewares []orbserver.Middleware + hServer *server.Hertz // entrypointID is the entrypointID (uuid) of this entrypoint in the registry. @@ -210,10 +212,10 @@ func (s *Server) registryDeregister() error { return s.Registry.Deregister(rService) } -// ProvideServer creates a new entrypoint for a single address. You can create +// Provide creates a new entrypoint for a single address. You can create // multiple entrypoints for multiple addresses and ports. One entrypoint // can serve a HTTP1 and HTTP2/H2C server. -func ProvideServer( +func Provide( _ types.ServiceName, logger log.Logger, reg registry.Type, @@ -240,11 +242,22 @@ func ProvideServer( logger = logger.With(slog.String("entrypoint", cfg.Name)) + mws := []orbserver.Middleware{} + + for _, m := range cfg.Middlewares { + if mw, ok := orbserver.Middlewares.Get(m); ok { + mws = append(mws, mw) + } else { + logger.Error("unknown middleware given", "middleware", m) + } + } + entrypoint := Server{ - Config: cfg, - Logger: logger, - Registry: reg, - codecs: codecs, + Config: cfg, + Logger: logger, + Registry: reg, + middlewares: mws, + codecs: codecs, } return &entrypoint, nil diff --git a/server/hertz/plugin.go b/server/hertz/plugin.go index e69dc264..441fc547 100644 --- a/server/hertz/plugin.go +++ b/server/hertz/plugin.go @@ -23,7 +23,7 @@ func pluginProvider( return nil, ErrInvalidConfigType } - return ProvideServer(service, logger, reg, *cfg) + return Provide(service, logger, reg, *cfg) } func newDefaultConfig() server.EntrypointConfig { diff --git a/server/http/config.go b/server/http/config.go index 58d1628c..9ad63299 100644 --- a/server/http/config.go +++ b/server/http/config.go @@ -4,12 +4,10 @@ import ( "crypto/tls" "errors" "fmt" - "net/http" "time" - "log/slog" - "github.com/go-orb/go-orb/codecs" + "github.com/go-orb/go-orb/log" "github.com/go-orb/go-orb/server" "github.com/go-orb/go-orb/util/slicemap" mtls "github.com/go-orb/go-orb/util/tls" @@ -200,15 +198,12 @@ type Config struct { // Handlers global, and setting them explicitly in the config. HandlerRegistrations server.HandlerRegistrations `json:"handlers" yaml:"handlers"` - // Middleware is a list of middleware to use. - Middleware router.Middlewares `json:"middleware" yaml:"middleware"` + // Middlewares is a list of middleware to use. + Middlewares []string `json:"middlewares" yaml:"middlewares"` // Logger allows you to dynamically change the log level and plugin for a // specific entrypoint. - Logger struct { - Level slog.Level `json:"level,omitempty" yaml:"level,omitempty"` // TODO(davincible): change with custom level - Plugin string `json:"plugin,omitempty" yaml:"plugin,omitempty"` - } `json:"logger" yaml:"logger"` + Logger log.Config `json:"logger" yaml:"logger"` } // NewConfig will create a new default config for the entrypoint. @@ -229,7 +224,7 @@ func NewConfig(options ...Option) *Config { WriteTimeout: DefaultWriteTimeout, IdleTimeout: DefaultIdleTimeout, HandlerRegistrations: make(server.HandlerRegistrations), - Middleware: make(router.Middlewares), + Middlewares: []string{}, } cfg.ApplyOptions(options...) @@ -447,16 +442,14 @@ func WithRegistration(name string, registration server.RegistrationFunc) Option // // Each middlware is uniquely identified with a name. The name provided here // can be used to dynamically add middlware to an entrypoint in a config. -func WithMiddleware(name string, middleware func(http.Handler) http.Handler) Option { - router.Middleware.Set(name, middleware) - +func WithMiddleware(middlewares ...string) Option { return func(c *Config) { - c.Middleware[name] = middleware + c.Middlewares = append(c.Middlewares, middlewares...) } } // WithLogLevel changes the log level from the inherited logger. -func WithLogLevel(level slog.Level) Option { +func WithLogLevel(level string) Option { return func(c *Config) { c.Logger.Level = level } diff --git a/server/http/entrypoint.go b/server/http/entrypoint.go index 9469e307..41d90706 100644 --- a/server/http/entrypoint.go +++ b/server/http/entrypoint.go @@ -47,6 +47,8 @@ type ServerHTTP struct { Logger log.Logger Registry registry.Type + middlewares []server.Middleware + // entrypointID is the entrypointID (uuid) of this entrypoint in the registry. entrypointID string @@ -69,11 +71,11 @@ type ServerHTTP struct { activeRequests int64 // accessed atomically } -// ProvideServerHTTP creates a new entrypoint for a single address. You can create +// Provide creates a new entrypoint for a single address. You can create // multiple entrypoints for multiple addresses and ports. One entrypoint // can serve a HTTP1, HTTP2 and HTTP3 server. If you enable HTTP3 it will listen // on both TCP and UDP on the same port. -func ProvideServerHTTP( +func Provide( _ types.ServiceName, logger log.Logger, reg registry.Type, @@ -105,12 +107,23 @@ func ProvideServerHTTP( logger = logger.With(slog.String("entrypoint", cfg.Name)) + mws := []server.Middleware{} + + for _, m := range cfg.Middlewares { + if mw, ok := server.Middlewares.Get(m); ok { + mws = append(mws, mw) + } else { + logger.Error("unknown middleware given", "middleware", m) + } + } + entrypoint := ServerHTTP{ - Config: cfg, - Logger: logger, - Registry: reg, - codecs: codecs, - router: router, + Config: cfg, + Logger: logger, + Registry: reg, + middlewares: mws, + codecs: codecs, + router: router, } entrypoint.Config.TLS, err = entrypoint.setupTLS() @@ -140,10 +153,6 @@ func (s *ServerHTTP) Start() error { s.Logger.Info("Starting", "address", s.Config.Address) - for _, middleware := range s.Config.Middleware { - s.router.Use(middleware) - } - for _, h := range s.Config.HandlerRegistrations { h(s) } diff --git a/server/http/handler.go b/server/http/handler.go index 97fb6e9b..b62d1b78 100644 --- a/server/http/handler.go +++ b/server/http/handler.go @@ -5,15 +5,14 @@ import ( "errors" "fmt" "net/http" + "slices" "strconv" - "strings" "github.com/go-orb/go-orb/util/metadata" "github.com/go-orb/go-orb/util/orberrors" ) -// orbHeader is the prefix for every orb HTTP header. -const orbHeader = "__orb-" +var stdHeaders = []string{"Accept", "Accept-Encoding", "Content-Length", "Content-Type", "User-Agent"} //nolint:gochecknoglobals // Errors. var ( @@ -21,27 +20,32 @@ var ( ) // NewGRPCHandler will wrap a gRPC function with a HTTP handler. -func NewGRPCHandler[Tin any, Tout any](srv *ServerHTTP, fHandler func(context.Context, *Tin) (*Tout, error)) http.HandlerFunc { +func NewGRPCHandler[Tin any, Tout any]( + srv *ServerHTTP, + fHandler func(context.Context, *Tin) (*Tout, error), + service string, + method string, +) http.HandlerFunc { return func(resp http.ResponseWriter, req *http.Request) { inBody := new(Tin) if _, err := srv.decodeBody(resp, req, inBody); err != nil { - srv.Logger.Error("failed to decode body", "error", err) + srv.Logger.Error("failed to decode request body", "error", err) WriteError(resp, orberrors.ErrBadRequest.Wrap(err)) return } // Copy metadata from req Headers into the req.Context. - reqMd := make(metadata.Metadata) + ctx := metadata.EnsureIncoming(req.Context()) + ctx = metadata.EnsureOutgoing(ctx) + reqMd, _ := metadata.IncomingFrom(ctx) for k, v := range req.Header { - if !strings.HasPrefix(strings.ToLower(k), orbHeader) { + if slices.Contains(stdHeaders, k) { continue } - k = k[len(orbHeader):] - if len(v) == 1 { reqMd[k] = v[0] } else { @@ -52,7 +56,19 @@ func NewGRPCHandler[Tin any, Tout any](srv *ServerHTTP, fHandler func(context.Co } } - out, err := fHandler(reqMd.To(req.Context()), inBody) + reqMd[metadata.Service] = service + reqMd[metadata.Method] = method + + // Apply middleware. + h := func(ctx context.Context, req any) (any, error) { + return fHandler(ctx, req.(*Tin)) + } + for _, m := range srv.middlewares { + h = m.Call(h) + } + + // The actual call. + out, err := h(ctx, inBody) if err != nil { srv.Logger.Error("RPC request failed", "error", err) WriteError(resp, err) @@ -60,13 +76,15 @@ func NewGRPCHandler[Tin any, Tout any](srv *ServerHTTP, fHandler func(context.Co return } - // Write back metadata to headers. - for k, v := range reqMd { - resp.Header().Set(orbHeader+k, v) + // Write outgoing metadata. + if md, ok := metadata.OutgoingFrom(ctx); ok { + for k, v := range md { + resp.Header().Set(k, v) + } } if err := srv.encodeBody(resp, req, out); err != nil { - srv.Logger.Error("failed to encode body", "error", err) + srv.Logger.Error("failed to encode response body", "error", err) WriteError(resp, err) return diff --git a/server/http/plugin.go b/server/http/plugin.go index a2cf9ab9..d4d2be10 100644 --- a/server/http/plugin.go +++ b/server/http/plugin.go @@ -23,7 +23,7 @@ func pluginProvider( return nil, ErrInvalidConfigType } - return ProvideServerHTTP(service, logger, reg, *cfg) + return Provide(service, logger, reg, *cfg) } func newDefaultConfig() server.EntrypointConfig { diff --git a/server/http/tests/http_test.go b/server/http/tests/http_test.go index 800ab159..0acbc558 100644 --- a/server/http/tests/http_test.go +++ b/server/http/tests/http_test.go @@ -348,15 +348,15 @@ func TestServerIntegration(t *testing.T) { name := types.ServiceName("com.example.test") version := types.ServiceVersion("v1.0.0") - logger, err := log.ProvideLogger(name, nil) + logger, err := log.Provide(name, nil) require.NoError(t, err, "failed to setup the logger") - reg, err := registry.ProvideRegistry(name, version, nil, logger) + reg, err := registry.Provide(name, version, nil, logger) require.NoError(t, err, "failed to setup the registry") h := new(handler.EchoHandler) - srv, err := server.ProvideServer(name, nil, logger, reg, + srv, err := server.Provide(name, nil, logger, reg, mhttp.WithEntrypoint( mhttp.WithName("test-ep-1"), mhttp.WithAddress(":48081"), @@ -419,14 +419,14 @@ func TestServerFileConfig(t *testing.T) { config, err := config.Read([]*url.URL{fURL}, nil) require.NoError(t, err, "failed to read file config") - logger, err := log.ProvideLogger(name, nil) + logger, err := log.Provide(name, nil) require.NoError(t, err, "failed to setup the logger") - reg, err := registry.ProvideRegistry(name, version, nil, logger) + reg, err := registry.Provide(name, version, nil, logger) require.NoError(t, err, "failed to setup the registry") h := new(handler.EchoHandler) - srv, err := server.ProvideServer(name, config, logger, reg, + srv, err := server.Provide(name, config, logger, reg, // TODO(davincible): test defaults mhttp.WithEntrypoint( mhttp.WithName("static-ep-1"), @@ -436,7 +436,6 @@ func TestServerFileConfig(t *testing.T) { ), mhttp.WithEntrypoint( mhttp.WithName("test-ep-5"), - mhttp.WithMiddleware("middleware-3", func(h http.Handler) http.Handler { return h }), ), ) require.NoError(t, err, "failed to setup server") @@ -482,7 +481,6 @@ func TestServerFileConfig(t *testing.T) { ep = e.(*mhttp.ServerHTTP) //nolint:errcheck require.True(t, strings.HasSuffix(ep.Config.Address, ":4516")) require.Len(t, ep.Config.HandlerRegistrations, 3, "Registration len") - require.Len(t, ep.Config.Middleware, 4, "Middleware len") makeRequests(t, "https://"+e.Address(), thttp.TypeHTTP2) require.NoError(t, srv.Stop(context.Background()), "failed to start server") @@ -675,12 +673,12 @@ func setupServer(tb testing.TB, nolog bool, opts ...mhttp.Option) (*mhttp.Server cancel := func() {} - logger, err := log.ProvideLogger(name, nil, lopts...) + logger, err := log.Provide(name, nil, lopts...) if err != nil { return nil, cancel, fmt.Errorf("failed to setup logger: %w", err) } - reg, err := registry.ProvideRegistry("app", "v1.0.0", nil, logger) + reg, err := registry.Provide("app", "v1.0.0", nil, logger) if err != nil { return nil, nil, fmt.Errorf("setup registry: %w", err) } @@ -692,7 +690,7 @@ func setupServer(tb testing.TB, nolog bool, opts ...mhttp.Option) (*mhttp.Server cfg := mhttp.NewConfig(opts...) - server, err := mhttp.ProvideServerHTTP(name, logger, reg, *cfg) + server, err := mhttp.Provide(name, logger, reg, *cfg) if err != nil { return nil, cancel, fmt.Errorf("failed to provide http server: %w", err) } diff --git a/server/http/tests/proto/echo_http.micro.pb.go b/server/http/tests/proto/echo_http.micro.pb.go index 3c14db3d..307f8ce9 100644 --- a/server/http/tests/proto/echo_http.micro.pb.go +++ b/server/http/tests/proto/echo_http.micro.pb.go @@ -19,8 +19,8 @@ import ( // RegisterStreamsHTTPHandler registers the service to an HTTP server. func RegisterStreamsHTTPHandler(srv *mhttp.ServerHTTP, handler StreamsServer) { r := srv.Router() - r.Get("/echo", mhttp.NewGRPCHandler(srv, handler.Call)) - r.Post("/echo", mhttp.NewGRPCHandler(srv, handler.Call)) + r.Get("/echo", mhttp.NewGRPCHandler(srv, handler.Call, "echo", "Call")) + r.Post("/echo", mhttp.NewGRPCHandler(srv, handler.Call, "echo", "Call")) } // RegisterStreamsHandler will return a registration function that can be diff --git a/server/tests/server_test.go b/server/tests/server_test.go index 0a994172..ab2be4b2 100644 --- a/server/tests/server_test.go +++ b/server/tests/server_test.go @@ -135,10 +135,10 @@ func TestMockConfigFile(t *testing.T) { logger, err := log.New() require.NoError(t, err, "failed to setup logger") - reg, err := registry.ProvideRegistry(service, version, nil, logger) + reg, err := registry.Provide(service, version, nil, logger) require.NoError(t, err, "failed to setup the registry") - srv, err := server.ProvideServer(service, data, logger, reg, + srv, err := server.Provide(service, data, logger, reg, WithMockDefaults(WithTest(t)), WithMockEntrypoint( WithMockName("static-ep-1"), @@ -193,10 +193,10 @@ func TestMockConfigFile(t *testing.T) { logger, err = log.New() require.NoError(t, err, "failed to setup the logger") - reg, err = registry.ProvideRegistry(service, version, nil, logger) + reg, err = registry.Provide(service, version, nil, logger) require.NoError(t, err, "failed to setup the registry") - srv, err = server.ProvideServer(service, data, logger, reg, WithMockDefaults(WithTest(t))) + srv, err = server.Provide(service, data, logger, reg, WithMockDefaults(WithTest(t))) require.NoError(t, err, "failed to setup server") require.NoError(t, srv.Start(), "failed to start server") @@ -216,7 +216,7 @@ func TestMockConfigFile(t *testing.T) { logger, err = log.New() require.NoError(t, err, "failed to setup logger") - srv, err = server.ProvideServer(service, data, logger, reg, WithMockDefaults(WithTest(t))) + srv, err = server.Provide(service, data, logger, reg, WithMockDefaults(WithTest(t))) t.Logf("expected error: %v", err) require.Error(t, err, "should fail to setup server for "+service) } @@ -279,12 +279,12 @@ func setupServer(opts ...server.Option) (server.Server, error) { return server.Server{}, fmt.Errorf("failed to setup logger: %w", err) } - reg, err := registry.ProvideRegistry(service, version, nil, logger) + reg, err := registry.Provide(service, version, nil, logger) if err != nil { return server.Server{}, fmt.Errorf("failed to setup registry: %w", err) } - srv, err := server.ProvideServer(service, nil, logger, reg, opts...) + srv, err := server.Provide(service, nil, logger, reg, opts...) if err != nil { return srv, fmt.Errorf("failed to setup server: %w", err) }