Skip to content
This repository has been archived by the owner on May 21, 2024. It is now read-only.

Commit

Permalink
Refactoring of the connection params
Browse files Browse the repository at this point in the history
This minor refactoring is a pre-requisition of the following extraction
of the logic of parsing a gRPC metadata.
  • Loading branch information
olegbespalov committed Sep 14, 2023
1 parent 0250ce4 commit 2c61f57
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 109 deletions.
111 changes: 2 additions & 109 deletions grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/grafana/xk6-grpc/lib/netext/grpcext"
"go.k6.io/k6/js/common"
"go.k6.io/k6/js/modules"
"go.k6.io/k6/lib/types"

"github.com/dop251/goja"
"github.com/jhump/protoreflect/desc"
Expand Down Expand Up @@ -205,13 +204,13 @@ func buildTLSConfigFromMap(parentConfig *tls.Config, tlsConfigMap map[string]int
}

// Connect is a block dial to the gRPC server at the given address (host:port)
func (c *Client) Connect(addr string, params map[string]interface{}) (bool, error) {
func (c *Client) Connect(addr string, params goja.Value) (bool, error) {
state := c.vu.State()
if state == nil {
return false, common.NewInitContextError("connecting to a gRPC server in the init context is not supported")
}

p, err := c.parseConnectParams(params)
p, err := newConnectParams(c.vu, params)
if err != nil {
return false, fmt.Errorf("invalid grpc.connect() parameters: %w", err)
}
Expand Down Expand Up @@ -418,112 +417,6 @@ func (c *Client) convertToMethodInfo(fdset *descriptorpb.FileDescriptorSet) ([]M
return rtn, nil
}

type connectParams struct {
IsPlaintext bool
UseReflectionProtocol bool
Timeout time.Duration
MaxReceiveSize int64
MaxSendSize int64
TLS map[string]interface{}
}

func (c *Client) parseConnectParams(raw map[string]interface{}) (connectParams, error) {
params := connectParams{
IsPlaintext: false,
UseReflectionProtocol: false,
Timeout: time.Minute,
MaxReceiveSize: 0,
MaxSendSize: 0,
}
for k, v := range raw {
switch k {
case "plaintext":
var ok bool
params.IsPlaintext, ok = v.(bool)
if !ok {
return params, fmt.Errorf("invalid plaintext value: '%#v', it needs to be boolean", v)
}
case "timeout":
var err error
params.Timeout, err = types.GetDurationValue(v)
if err != nil {
return params, fmt.Errorf("invalid timeout value: %w", err)
}
case "reflect":
var ok bool
params.UseReflectionProtocol, ok = v.(bool)
if !ok {
return params, fmt.Errorf("invalid reflect value: '%#v', it needs to be boolean", v)
}
case "maxReceiveSize":
var ok bool
params.MaxReceiveSize, ok = v.(int64)
if !ok {
return params, fmt.Errorf("invalid maxReceiveSize value: '%#v', it needs to be an integer", v)
}
if params.MaxReceiveSize < 0 {
return params, fmt.Errorf("invalid maxReceiveSize value: '%#v, it needs to be a positive integer", v)
}
case "maxSendSize":
var ok bool
params.MaxSendSize, ok = v.(int64)
if !ok {
return params, fmt.Errorf("invalid maxSendSize value: '%#v', it needs to be an integer", v)
}
if params.MaxSendSize < 0 {
return params, fmt.Errorf("invalid maxSendSize value: '%#v, it needs to be a positive integer", v)
}
case "tls":
if err := parseConnectTLSParam(&params, v); err != nil {
return params, err
}
default:
return params, fmt.Errorf("unknown connect param: %q", k)
}
}
return params, nil
}

func parseConnectTLSParam(params *connectParams, v interface{}) error {
var ok bool
params.TLS, ok = v.(map[string]interface{})

if !ok {
return fmt.Errorf("invalid tls value: '%#v', expected (optional) keys: cert, key, password, and cacerts", v)
}
// optional map keys below
if cert, certok := params.TLS["cert"]; certok {
if _, ok = cert.(string); !ok {
return fmt.Errorf("invalid tls cert value: '%#v', it needs to be a PEM formatted string", v)
}
}
if key, keyok := params.TLS["key"]; keyok {
if _, ok = key.(string); !ok {
return fmt.Errorf("invalid tls key value: '%#v', it needs to be a PEM formatted string", v)
}
}
if pass, passok := params.TLS["password"]; passok {
if _, ok = pass.(string); !ok {
return fmt.Errorf("invalid tls password value: '%#v', it needs to be a string", v)
}
}
if cacerts, cacertsok := params.TLS["cacerts"]; cacertsok {
var cacertsArray []interface{}
if cacertsArray, ok = cacerts.([]interface{}); ok {
for _, cacertsArrayEntry := range cacertsArray {
if _, ok = cacertsArrayEntry.(string); !ok {
return fmt.Errorf("invalid tls cacerts value: '%#v',"+
" it needs to be a string or an array of PEM formatted strings", v)
}
}
} else if _, ok = cacerts.(string); !ok {
return fmt.Errorf("invalid tls cacerts value: '%#v',"+
" it needs to be a string or an array of PEM formatted strings", v)
}
}
return nil
}

func walkFileDescriptors(seen map[string]struct{}, fd *desc.FileDescriptor) []*descriptorpb.FileDescriptorProto {
fds := []*descriptorpb.FileDescriptorProto{}

Expand Down
119 changes: 119 additions & 0 deletions grpc/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func newCallParams(vu modules.VU, input goja.Value) (*callParams, error) {
if !ok {
return result, errors.New("metadata must be an object with key-value pairs")
}

for hk, kv := range rawHeaders {
var val string

Expand Down Expand Up @@ -99,3 +100,121 @@ func (p *callParams) SetSystemTags(state *lib.State, addr string, methodName str
p.TagsAndMeta.SetSystemTagOrMetaIfEnabled(state.Options.SystemTags, metrics.TagName, methodName)
}
}

// connectParams is the parameters that can be passed to a gRPC connect call.
type connectParams struct {
IsPlaintext bool
UseReflectionProtocol bool
Timeout time.Duration
MaxReceiveSize int64
MaxSendSize int64
TLS map[string]interface{}
}

func newConnectParams(vu modules.VU, input goja.Value) (*connectParams, error) {
result := &connectParams{
IsPlaintext: false,
UseReflectionProtocol: false,
Timeout: time.Minute,
MaxReceiveSize: 0,
MaxSendSize: 0,
}

if input == nil || goja.IsUndefined(input) || goja.IsNull(input) {
return result, nil
}

rt := vu.Runtime()
params := input.ToObject(rt)

for _, k := range params.Keys() {
v := params.Get(k).Export()

switch k {
case "plaintext":
var ok bool
result.IsPlaintext, ok = v.(bool)
if !ok {
return result, fmt.Errorf("invalid plaintext value: '%#v', it needs to be boolean", v)
}
case "timeout":
var err error
result.Timeout, err = types.GetDurationValue(v)
if err != nil {
return result, fmt.Errorf("invalid timeout value: %w", err)
}
case "reflect":
var ok bool
result.UseReflectionProtocol, ok = v.(bool)
if !ok {
return result, fmt.Errorf("invalid reflect value: '%#v', it needs to be boolean", v)
}
case "maxReceiveSize":
var ok bool
result.MaxReceiveSize, ok = v.(int64)
if !ok {
return result, fmt.Errorf("invalid maxReceiveSize value: '%#v', it needs to be an integer", v)
}
if result.MaxReceiveSize < 0 {
return result, fmt.Errorf("invalid maxReceiveSize value: '%#v, it needs to be a positive integer", v)
}
case "maxSendSize":
var ok bool
result.MaxSendSize, ok = v.(int64)
if !ok {
return result, fmt.Errorf("invalid maxSendSize value: '%#v', it needs to be an integer", v)
}
if result.MaxSendSize < 0 {
return result, fmt.Errorf("invalid maxSendSize value: '%#v, it needs to be a positive integer", v)
}
case "tls":
if err := parseConnectTLSParam(result, v); err != nil {
return result, err
}
default:
return result, fmt.Errorf("unknown connect param: %q", k)
}
}

return result, nil
}

func parseConnectTLSParam(params *connectParams, v interface{}) error {
var ok bool
params.TLS, ok = v.(map[string]interface{})

if !ok {
return fmt.Errorf("invalid tls value: '%#v', expected (optional) keys: cert, key, password, and cacerts", v)
}
// optional map keys below
if cert, certok := params.TLS["cert"]; certok {
if _, ok = cert.(string); !ok {
return fmt.Errorf("invalid tls cert value: '%#v', it needs to be a PEM formatted string", v)
}
}
if key, keyok := params.TLS["key"]; keyok {
if _, ok = key.(string); !ok {
return fmt.Errorf("invalid tls key value: '%#v', it needs to be a PEM formatted string", v)
}
}
if pass, passok := params.TLS["password"]; passok {
if _, ok = pass.(string); !ok {
return fmt.Errorf("invalid tls password value: '%#v', it needs to be a string", v)
}
}
if cacerts, cacertsok := params.TLS["cacerts"]; cacertsok {
var cacertsArray []interface{}
if cacertsArray, ok = cacerts.([]interface{}); ok {
for _, cacertsArrayEntry := range cacertsArray {
if _, ok = cacertsArrayEntry.(string); !ok {
return fmt.Errorf("invalid tls cacerts value: '%#v',"+
" it needs to be a string or an array of PEM formatted strings", v)
}
}
} else if _, ok = cacerts.(string); !ok {
return fmt.Errorf("invalid tls cacerts value: '%#v',"+
" it needs to be a string or an array of PEM formatted strings", v)
}
}
return nil
}

0 comments on commit 2c61f57

Please sign in to comment.