Skip to content
This repository has been archived by the owner on May 21, 2024. It is now read-only.

Commit

Permalink
This is a port of the PR changes found here: (#25)
Browse files Browse the repository at this point in the history
* This is a port of the PR changes found here:
grafana/k6#3159

Note: The bad auth/bad certificate tests appear to be very inconsistent. In grafana/k6 I had to remove the timeout for them to pass in CI. Removing the timeout in grafana/xk6-grpc caused the tests to fail locally. This may need to be investigated or a bug raised to determine if k6 is not raising grpc connection errors on bad authentication until context deadline.
  • Loading branch information
chrismoran-mica authored Jul 17, 2023
1 parent d2b2ce4 commit fb49221
Show file tree
Hide file tree
Showing 2 changed files with 361 additions and 2 deletions.
150 changes: 148 additions & 2 deletions grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package grpc

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -106,6 +109,101 @@ func (c *Client) LoadProtoset(protosetPath string) ([]MethodInfo, error) {
return c.convertToMethodInfo(fdset)
}

// Note: this function was lifted from `lib/options.go`
func decryptPrivateKey(key, password []byte) ([]byte, error) {
block, _ := pem.Decode(key)
if block == nil {
return nil, errors.New("failed to decode PEM key")
}

blockType := block.Type
if blockType == "ENCRYPTED PRIVATE KEY" {
return nil, errors.New("encrypted pkcs8 formatted key is not supported")
}
/*
Even though `DecryptPEMBlock` has been deprecated since 1.16.x it is still
being used here because it is deprecated due to it not supporting *good* cryptography
ultimately though we want to support something so we will be using it for now.
*/
decryptedKey, err := x509.DecryptPEMBlock(block, password) //nolint:staticcheck
if err != nil {
return nil, err
}
key = pem.EncodeToMemory(&pem.Block{
Type: blockType,
Bytes: decryptedKey,
})
return key, nil
}

func buildTLSConfig(parentConfig *tls.Config, certificate, key []byte, caCertificates [][]byte) (*tls.Config, error) {
var cp *x509.CertPool
if len(caCertificates) > 0 {
cp, _ = x509.SystemCertPool()
for i, caCert := range caCertificates {
if ok := cp.AppendCertsFromPEM(caCert); !ok {
return nil, fmt.Errorf("failed to append ca certificate [%d] from PEM", i)
}
}
}

// Ignoring 'TLS MinVersion is too low' because this tls.Config will inherit MinValue and MaxValue
// from the vu state tls.Config

//nolint:golint,gosec
tlsCfg := &tls.Config{
CipherSuites: parentConfig.CipherSuites,
InsecureSkipVerify: parentConfig.InsecureSkipVerify,
MinVersion: parentConfig.MinVersion,
MaxVersion: parentConfig.MaxVersion,
Renegotiation: parentConfig.Renegotiation,
RootCAs: cp,
}
if len(certificate) > 0 && len(key) > 0 {
cert, err := tls.X509KeyPair(certificate, key)
if err != nil {
return nil, fmt.Errorf("failed to append certificate from PEM: %w", err)
}
tlsCfg.Certificates = []tls.Certificate{cert}
}
return tlsCfg, nil
}

func buildTLSConfigFromMap(parentConfig *tls.Config, tlsConfigMap map[string]interface{}) (*tls.Config, error) {
var cert, key, pass []byte
var ca [][]byte
var err error
if certstr, ok := tlsConfigMap["cert"].(string); ok {
cert = []byte(certstr)
}
if keystr, ok := tlsConfigMap["key"].(string); ok {
key = []byte(keystr)
}
if passwordStr, ok := tlsConfigMap["password"].(string); ok {
pass = []byte(passwordStr)
if len(pass) > 0 {
if key, err = decryptPrivateKey(key, pass); err != nil {
return nil, err
}
}
}
if cas, ok := tlsConfigMap["cacerts"]; ok {
var caCertsArray []interface{}
if caCertsArray, ok = cas.([]interface{}); ok {
ca = make([][]byte, len(caCertsArray))
for i, entry := range caCertsArray {
var entryStr string
if entryStr, ok = entry.(string); ok {
ca[i] = []byte(entryStr)
}
}
} else if caCertStr, caCertStrOk := cas.(string); caCertStrOk {
ca = [][]byte{[]byte(caCertStr)}
}
}
return buildTLSConfig(parentConfig, cert, key, ca)
}

// Connect is a block dial to the gRPC server at the given address (host:port)
func (c *Client) Connect(addr string, params map[string]interface{}) (bool, error) {
state := c.vu.State()
Expand All @@ -123,9 +221,13 @@ func (c *Client) Connect(addr string, params map[string]interface{}) (bool, erro
var tcred credentials.TransportCredentials
if !p.IsPlaintext {
tlsCfg := state.TLSConfig.Clone()
if len(p.TLS) > 0 {
if tlsCfg, err = buildTLSConfigFromMap(tlsCfg, p.TLS); err != nil {
return false, err
}
}
tlsCfg.NextProtos = []string{"h2"}

// TODO(rogchap): Would be good to add support for custom RootCAs (self signed)
tcred = credentials.NewTLS(tlsCfg)
} else {
tcred = insecure.NewCredentials()
Expand Down Expand Up @@ -322,6 +424,7 @@ type connectParams struct {
Timeout time.Duration
MaxReceiveSize int64
MaxSendSize int64
TLS map[string]interface{}
}

func (c *Client) parseConnectParams(raw map[string]interface{}) (connectParams, error) {
Expand Down Expand Up @@ -370,14 +473,57 @@ func (c *Client) parseConnectParams(raw map[string]interface{}) (connectParams,
if params.MaxSendSize < 0 {
return params, fmt.Errorf("invalid maxSendSize value: '%#v, it needs to be a positive integer", v)
}

case "tls":
if err := parseConnectTLSParam(&params, v); err != nil {
return params, err
}
default:
return params, fmt.Errorf("unknown connect param: %q", k)
}
}
return params, nil
}

func parseConnectTLSParam(params *connectParams, v interface{}) error {
var ok bool
params.TLS, ok = v.(map[string]interface{})

if !ok {
return fmt.Errorf("invalid tls value: '%#v', expected (optional) keys: cert, key, password, and cacerts", v)
}
// optional map keys below
if cert, certok := params.TLS["cert"]; certok {
if _, ok = cert.(string); !ok {
return fmt.Errorf("invalid tls cert value: '%#v', it needs to be a PEM formatted string", v)
}
}
if key, keyok := params.TLS["key"]; keyok {
if _, ok = key.(string); !ok {
return fmt.Errorf("invalid tls key value: '%#v', it needs to be a PEM formatted string", v)
}
}
if pass, passok := params.TLS["password"]; passok {
if _, ok = pass.(string); !ok {
return fmt.Errorf("invalid tls password value: '%#v', it needs to be a string", v)
}
}
if cacerts, cacertsok := params.TLS["cacerts"]; cacertsok {
var cacertsArray []interface{}
if cacertsArray, ok = cacerts.([]interface{}); ok {
for _, cacertsArrayEntry := range cacertsArray {
if _, ok = cacertsArrayEntry.(string); !ok {
return fmt.Errorf("invalid tls cacerts value: '%#v',"+
" it needs to be a string or an array of PEM formatted strings", v)
}
}
} else if _, ok = cacerts.(string); !ok {
return fmt.Errorf("invalid tls cacerts value: '%#v',"+
" it needs to be a string or an array of PEM formatted strings", v)
}
}
return nil
}

func walkFileDescriptors(seen map[string]struct{}, fd *desc.FileDescriptor) []*descriptorpb.FileDescriptorProto {
fds := []*descriptorpb.FileDescriptorProto{}

Expand Down
Loading

0 comments on commit fb49221

Please sign in to comment.