diff --git a/grpc/client.go b/grpc/client.go index 8a56939..ef46ffc 100644 --- a/grpc/client.go +++ b/grpc/client.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" @@ -257,6 +258,9 @@ func (c *Client) Connect(addr string, params goja.Value) (bool, error) { if !p.UseReflectionProtocol { return true, nil } + + ctx = metadata.NewOutgoingContext(ctx, p.ReflectionMetadata) + fdset, err := c.conn.Reflect(ctx) if err != nil { return false, err diff --git a/grpc/client_test.go b/grpc/client_test.go index 4f8585c..f76429f 100644 --- a/grpc/client_test.go +++ b/grpc/client_test.go @@ -16,6 +16,7 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/known/wrapperspb" + "gopkg.in/guregu/null.v3" "github.com/golang/protobuf/ptypes/any" "github.com/golang/protobuf/ptypes/wrappers" @@ -1128,6 +1129,49 @@ func TestDebugStat(t *testing.T) { } } +func TestClientConnectionReflectMetadata(t *testing.T) { + t.Parallel() + + ts := newTestState(t) + + reflection.Register(ts.httpBin.ServerGRPC) + + initString := codeBlock{ + code: `var client = new grpc.Client();`, + } + vuString := codeBlock{ + code: `client.connect("GRPCBIN_ADDR", {reflect: true, reflectMetadata: {"x-test": "custom-header-for-reflection"}})`, + } + + val, err := ts.Run(initString.code) + assertResponse(t, initString, err, val, ts) + + ts.ToVUContext() + + // this should trigger logging of the outgoing gRPC metadata + ts.VU.State().Options.HTTPDebug = null.NewString("full", true) + + val, err = ts.Run(vuString.code) + assertResponse(t, vuString, err, val, ts) + + entries := ts.loggerHook.Drain() + + // since we enable debug logging, we should see the metadata in the logs + foundReflectionCall := false + for _, entry := range entries { + if strings.Contains(entry.Message, "ServerReflection/ServerReflectionInfo") { + foundReflectionCall = true + + // check that the metadata is present + assert.Contains(t, entry.Message, "x-test: custom-header-for-reflection") + // check that user-agent header is present + assert.Contains(t, entry.Message, "user-agent: k6-test") + } + } + + assert.True(t, foundReflectionCall, "expected to find a reflection call in the logs, but didn't") +} + func TestClientLoadProto(t *testing.T) { t.Parallel() diff --git a/grpc/params.go b/grpc/params.go index 04983ad..d9fb913 100644 --- a/grpc/params.go +++ b/grpc/params.go @@ -123,19 +123,21 @@ func (p *callParams) SetSystemTags(state *lib.State, addr string, methodName str type connectParams struct { IsPlaintext bool UseReflectionProtocol bool + ReflectionMetadata metadata.MD Timeout time.Duration MaxReceiveSize int64 MaxSendSize int64 TLS map[string]interface{} } -func newConnectParams(vu modules.VU, input goja.Value) (*connectParams, error) { +func newConnectParams(vu modules.VU, input goja.Value) (*connectParams, error) { //nolint:gocognit result := &connectParams{ IsPlaintext: false, UseReflectionProtocol: false, Timeout: time.Minute, MaxReceiveSize: 0, MaxSendSize: 0, + ReflectionMetadata: metadata.New(nil), } if common.IsNullish(input) { @@ -167,6 +169,13 @@ func newConnectParams(vu modules.VU, input goja.Value) (*connectParams, error) { if !ok { return result, fmt.Errorf("invalid reflect value: '%#v', it needs to be boolean", v) } + case "reflectMetadata": + md, err := newMetadata(params.Get(k)) + if err != nil { + return result, fmt.Errorf("invalid reflectMetadata param: %w", err) + } + + result.ReflectionMetadata = md case "maxReceiveSize": var ok bool result.MaxReceiveSize, ok = v.(int64) diff --git a/grpc/teststate_test.go b/grpc/teststate_test.go index f3bc669..fbc8ae4 100644 --- a/grpc/teststate_test.go +++ b/grpc/teststate_test.go @@ -15,6 +15,7 @@ import ( "go.k6.io/k6/js/modulestest" "go.k6.io/k6/lib" "go.k6.io/k6/lib/fsext" + "go.k6.io/k6/lib/testutils" "go.k6.io/k6/lib/testutils/httpmultibin" "go.k6.io/k6/metrics" "gopkg.in/guregu/null.v3" @@ -78,6 +79,7 @@ type testState struct { httpBin *httpmultibin.HTTPMultiBin samples chan metrics.SampleContainer logger logrus.FieldLogger + loggerHook *testutils.SimpleLogrusHook callRecorder *callRecorder } @@ -114,6 +116,9 @@ func newTestState(t *testing.T) testState { logger.SetLevel(logrus.InfoLevel) logger.Out = io.Discard + hook := testutils.NewLogHook() + logger.AddHook(hook) + recorder := &callRecorder{ calls: make([]string, 0), } @@ -123,6 +128,7 @@ func newTestState(t *testing.T) testState { httpBin: tb, samples: samples, logger: logger, + loggerHook: hook, callRecorder: recorder, }