Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions args_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type kingpinParser struct {
stream bool
certPath string
keyPath string
keyLogPath string
rate *nullableUint64
clientType clientTyp

Expand All @@ -54,6 +55,7 @@ func newKingpinParser() argsParser {
method: "GET",
body: "",
bodyFilePath: "",
keyLogPath: "",
stream: false,
certPath: "",
keyPath: "",
Expand Down Expand Up @@ -102,6 +104,9 @@ func newKingpinParser() argsParser {
app.Flag("key", "Path to the client's TLS Certificate Private Key").
Default("").
StringVar(&kparser.keyPath)
app.Flag("key-log-path", "Path used to log TLS keys for Wireshark").
Default("").
StringVar(&kparser.keyLogPath)
app.Flag("insecure",
"Controls whether a client verifies the server's certificate"+
" chain and host name").
Expand Down Expand Up @@ -222,6 +227,7 @@ func (k *kingpinParser) parse(args []string) (config, error) {
bodyFilePath: k.bodyFilePath,
stream: k.stream,
keyPath: k.keyPath,
keyLogPath: k.keyLogPath,
certPath: k.certPath,
printLatencies: k.latencies,
insecure: k.insecure,
Expand Down
3 changes: 3 additions & 0 deletions args_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,14 @@ func TestArgsParsing(t *testing.T) {
programName,
"--key", "testclient.key",
"--cert", "testclient.cert",
"--key-log-path", "/path/to/keylog",
"https://somehost.somedomain",
},
{
programName,
"--key=testclient.key",
"--cert=testclient.cert",
"--key-log-path=/path/to/keylog",
"https://somehost.somedomain",
},
},
Expand All @@ -214,6 +216,7 @@ func TestArgsParsing(t *testing.T) {
method: "GET",
keyPath: "testclient.key",
certPath: "testclient.cert",
keyLogPath: "/path/to/keylog",
url: ParseURLOrPanic("https://somehost.somedomain"),
printIntro: true,
printProgress: true,
Expand Down
16 changes: 14 additions & 2 deletions client_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package main

import (
"crypto/tls"
"io"
"os"
)

// readClientCert - helper function to read client certificate
Expand All @@ -16,8 +18,9 @@ func readClientCert(certPath, keyPath string) ([]tls.Certificate, error) {
// config
func generateTLSConfig(c config) (*tls.Config, error) {
var (
certs []tls.Certificate
err error
certs []tls.Certificate
keyLogWriter io.Writer
err error
)
// This assumes that the caller has validated that either both or none of
// the c.certPath and c.keyPath are set.
Expand All @@ -28,12 +31,21 @@ func generateTLSConfig(c config) (*tls.Config, error) {
}
}

if c.keyLogPath != "" {
f, err := os.OpenFile(c.keyLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return nil, err
}

keyLogWriter = f
}
// Disable gas warning, because InsecureSkipVerify may be set to true
// for the purpose of testing
/* #nosec */
tlsConfig := &tls.Config{
InsecureSkipVerify: c.insecure,
Certificates: certs,
KeyLogWriter: keyLogWriter,
}
return tlsConfig, nil
}
23 changes: 23 additions & 0 deletions client_cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,26 @@ func TestGenerateTLSConfig(t *testing.T) {
}
}
}

func TestInvalidKeyLogPath(t *testing.T) {
expectations := []struct {
keyLogPath string
errIsNil bool
}{
{
keyLogPath: "/path/to/invalid/log",
errIsNil: false,
},
}
for _, e := range expectations {
_, err := generateTLSConfig(
config{
url: ParseURLOrPanic("https://doesnt.exist.com"),
keyLogPath: e.keyLogPath,
},
)
if (err == nil) != e.errIsNil {
t.Error(e.keyLogPath, err)
}
}
}
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type config struct {
url *url.URL
method, certPath, keyPath string
body, bodyFilePath string
keyLogPath string
stream bool
headers *headersList
timeout time.Duration
Expand Down