From 2c61f57fbe8bfd51c0402de52b1fc9103f50bf09 Mon Sep 17 00:00:00 2001 From: Oleg Bespalov Date: Thu, 14 Sep 2023 10:49:08 +0200 Subject: [PATCH] Refactoring of the connection params This minor refactoring is a pre-requisition of the following extraction of the logic of parsing a gRPC metadata. --- grpc/client.go | 111 +-------------------------------------------- grpc/params.go | 119 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 109 deletions(-) diff --git a/grpc/client.go b/grpc/client.go index 7607d9c..8a56939 100644 --- a/grpc/client.go +++ b/grpc/client.go @@ -14,7 +14,6 @@ import ( "github.com/grafana/xk6-grpc/lib/netext/grpcext" "go.k6.io/k6/js/common" "go.k6.io/k6/js/modules" - "go.k6.io/k6/lib/types" "github.com/dop251/goja" "github.com/jhump/protoreflect/desc" @@ -205,13 +204,13 @@ func buildTLSConfigFromMap(parentConfig *tls.Config, tlsConfigMap map[string]int } // Connect is a block dial to the gRPC server at the given address (host:port) -func (c *Client) Connect(addr string, params map[string]interface{}) (bool, error) { +func (c *Client) Connect(addr string, params goja.Value) (bool, error) { state := c.vu.State() if state == nil { return false, common.NewInitContextError("connecting to a gRPC server in the init context is not supported") } - p, err := c.parseConnectParams(params) + p, err := newConnectParams(c.vu, params) if err != nil { return false, fmt.Errorf("invalid grpc.connect() parameters: %w", err) } @@ -418,112 +417,6 @@ func (c *Client) convertToMethodInfo(fdset *descriptorpb.FileDescriptorSet) ([]M return rtn, nil } -type connectParams struct { - IsPlaintext bool - UseReflectionProtocol bool - Timeout time.Duration - MaxReceiveSize int64 - MaxSendSize int64 - TLS map[string]interface{} -} - -func (c *Client) parseConnectParams(raw map[string]interface{}) (connectParams, error) { - params := connectParams{ - IsPlaintext: false, - UseReflectionProtocol: false, - Timeout: time.Minute, - MaxReceiveSize: 0, - MaxSendSize: 0, - } - for k, v := range raw { - switch k { - case "plaintext": - var ok bool - params.IsPlaintext, ok = v.(bool) - if !ok { - return params, fmt.Errorf("invalid plaintext value: '%#v', it needs to be boolean", v) - } - case "timeout": - var err error - params.Timeout, err = types.GetDurationValue(v) - if err != nil { - return params, fmt.Errorf("invalid timeout value: %w", err) - } - case "reflect": - var ok bool - params.UseReflectionProtocol, ok = v.(bool) - if !ok { - return params, fmt.Errorf("invalid reflect value: '%#v', it needs to be boolean", v) - } - case "maxReceiveSize": - var ok bool - params.MaxReceiveSize, ok = v.(int64) - if !ok { - return params, fmt.Errorf("invalid maxReceiveSize value: '%#v', it needs to be an integer", v) - } - if params.MaxReceiveSize < 0 { - return params, fmt.Errorf("invalid maxReceiveSize value: '%#v, it needs to be a positive integer", v) - } - case "maxSendSize": - var ok bool - params.MaxSendSize, ok = v.(int64) - if !ok { - return params, fmt.Errorf("invalid maxSendSize value: '%#v', it needs to be an integer", v) - } - if params.MaxSendSize < 0 { - return params, fmt.Errorf("invalid maxSendSize value: '%#v, it needs to be a positive integer", v) - } - case "tls": - if err := parseConnectTLSParam(¶ms, v); err != nil { - return params, err - } - default: - return params, fmt.Errorf("unknown connect param: %q", k) - } - } - return params, nil -} - -func parseConnectTLSParam(params *connectParams, v interface{}) error { - var ok bool - params.TLS, ok = v.(map[string]interface{}) - - if !ok { - return fmt.Errorf("invalid tls value: '%#v', expected (optional) keys: cert, key, password, and cacerts", v) - } - // optional map keys below - if cert, certok := params.TLS["cert"]; certok { - if _, ok = cert.(string); !ok { - return fmt.Errorf("invalid tls cert value: '%#v', it needs to be a PEM formatted string", v) - } - } - if key, keyok := params.TLS["key"]; keyok { - if _, ok = key.(string); !ok { - return fmt.Errorf("invalid tls key value: '%#v', it needs to be a PEM formatted string", v) - } - } - if pass, passok := params.TLS["password"]; passok { - if _, ok = pass.(string); !ok { - return fmt.Errorf("invalid tls password value: '%#v', it needs to be a string", v) - } - } - if cacerts, cacertsok := params.TLS["cacerts"]; cacertsok { - var cacertsArray []interface{} - if cacertsArray, ok = cacerts.([]interface{}); ok { - for _, cacertsArrayEntry := range cacertsArray { - if _, ok = cacertsArrayEntry.(string); !ok { - return fmt.Errorf("invalid tls cacerts value: '%#v',"+ - " it needs to be a string or an array of PEM formatted strings", v) - } - } - } else if _, ok = cacerts.(string); !ok { - return fmt.Errorf("invalid tls cacerts value: '%#v',"+ - " it needs to be a string or an array of PEM formatted strings", v) - } - } - return nil -} - func walkFileDescriptors(seen map[string]struct{}, fd *desc.FileDescriptor) []*descriptorpb.FileDescriptorProto { fds := []*descriptorpb.FileDescriptorProto{} diff --git a/grpc/params.go b/grpc/params.go index 14e3af9..f891c62 100644 --- a/grpc/params.go +++ b/grpc/params.go @@ -46,6 +46,7 @@ func newCallParams(vu modules.VU, input goja.Value) (*callParams, error) { if !ok { return result, errors.New("metadata must be an object with key-value pairs") } + for hk, kv := range rawHeaders { var val string @@ -99,3 +100,121 @@ func (p *callParams) SetSystemTags(state *lib.State, addr string, methodName str p.TagsAndMeta.SetSystemTagOrMetaIfEnabled(state.Options.SystemTags, metrics.TagName, methodName) } } + +// connectParams is the parameters that can be passed to a gRPC connect call. +type connectParams struct { + IsPlaintext bool + UseReflectionProtocol bool + Timeout time.Duration + MaxReceiveSize int64 + MaxSendSize int64 + TLS map[string]interface{} +} + +func newConnectParams(vu modules.VU, input goja.Value) (*connectParams, error) { + result := &connectParams{ + IsPlaintext: false, + UseReflectionProtocol: false, + Timeout: time.Minute, + MaxReceiveSize: 0, + MaxSendSize: 0, + } + + if input == nil || goja.IsUndefined(input) || goja.IsNull(input) { + return result, nil + } + + rt := vu.Runtime() + params := input.ToObject(rt) + + for _, k := range params.Keys() { + v := params.Get(k).Export() + + switch k { + case "plaintext": + var ok bool + result.IsPlaintext, ok = v.(bool) + if !ok { + return result, fmt.Errorf("invalid plaintext value: '%#v', it needs to be boolean", v) + } + case "timeout": + var err error + result.Timeout, err = types.GetDurationValue(v) + if err != nil { + return result, fmt.Errorf("invalid timeout value: %w", err) + } + case "reflect": + var ok bool + result.UseReflectionProtocol, ok = v.(bool) + if !ok { + return result, fmt.Errorf("invalid reflect value: '%#v', it needs to be boolean", v) + } + case "maxReceiveSize": + var ok bool + result.MaxReceiveSize, ok = v.(int64) + if !ok { + return result, fmt.Errorf("invalid maxReceiveSize value: '%#v', it needs to be an integer", v) + } + if result.MaxReceiveSize < 0 { + return result, fmt.Errorf("invalid maxReceiveSize value: '%#v, it needs to be a positive integer", v) + } + case "maxSendSize": + var ok bool + result.MaxSendSize, ok = v.(int64) + if !ok { + return result, fmt.Errorf("invalid maxSendSize value: '%#v', it needs to be an integer", v) + } + if result.MaxSendSize < 0 { + return result, fmt.Errorf("invalid maxSendSize value: '%#v, it needs to be a positive integer", v) + } + case "tls": + if err := parseConnectTLSParam(result, v); err != nil { + return result, err + } + default: + return result, fmt.Errorf("unknown connect param: %q", k) + } + } + + return result, nil +} + +func parseConnectTLSParam(params *connectParams, v interface{}) error { + var ok bool + params.TLS, ok = v.(map[string]interface{}) + + if !ok { + return fmt.Errorf("invalid tls value: '%#v', expected (optional) keys: cert, key, password, and cacerts", v) + } + // optional map keys below + if cert, certok := params.TLS["cert"]; certok { + if _, ok = cert.(string); !ok { + return fmt.Errorf("invalid tls cert value: '%#v', it needs to be a PEM formatted string", v) + } + } + if key, keyok := params.TLS["key"]; keyok { + if _, ok = key.(string); !ok { + return fmt.Errorf("invalid tls key value: '%#v', it needs to be a PEM formatted string", v) + } + } + if pass, passok := params.TLS["password"]; passok { + if _, ok = pass.(string); !ok { + return fmt.Errorf("invalid tls password value: '%#v', it needs to be a string", v) + } + } + if cacerts, cacertsok := params.TLS["cacerts"]; cacertsok { + var cacertsArray []interface{} + if cacertsArray, ok = cacerts.([]interface{}); ok { + for _, cacertsArrayEntry := range cacertsArray { + if _, ok = cacertsArrayEntry.(string); !ok { + return fmt.Errorf("invalid tls cacerts value: '%#v',"+ + " it needs to be a string or an array of PEM formatted strings", v) + } + } + } else if _, ok = cacerts.(string); !ok { + return fmt.Errorf("invalid tls cacerts value: '%#v',"+ + " it needs to be a string or an array of PEM formatted strings", v) + } + } + return nil +}