diff --git a/core/connection/tlsconfig.go b/core/connection/tlsconfig.go index 337e3e9a17..daf4a25d72 100644 --- a/core/connection/tlsconfig.go +++ b/core/connection/tlsconfig.go @@ -1,6 +1,7 @@ package connection import ( + "bytes" "crypto/tls" "crypto/x509" "encoding/asn1" @@ -9,10 +10,14 @@ import ( "errors" "fmt" "io/ioutil" + "strings" ) // TLSConfig contains options for configuring a TLS connection to the server. -type TLSConfig struct{ *tls.Config } +type TLSConfig struct { + *tls.Config + clientCertPass func() string +} // NewTLSConfig creates a new TLSConfig. func NewTLSConfig() *TLSConfig { @@ -22,6 +27,13 @@ func NewTLSConfig() *TLSConfig { return cfg } +// SetClientCertDecryptPassword sets a function to retrieve the decryption password +// necessary to read a certificate. This is a function instead of a string to +// provide greater flexibility when deciding how to retrieve and store the password. +func (c *TLSConfig) SetClientCertDecryptPassword(f func() string) { + c.clientCertPass = f +} + // SetInsecure sets whether the client should verify the server's certificate // chain and hostnames. func (c *TLSConfig) SetInsecure(allow bool) { @@ -63,7 +75,46 @@ func (c *TLSConfig) AddClientCertFromFile(clientFile string) (string, error) { return "", err } - cert, err := tls.X509KeyPair(data, data) + var currentBlock *pem.Block + var certBlock, certDecodedBlock, keyBlock []byte + + remaining := data + start := 0 + for { + currentBlock, remaining = pem.Decode(remaining) + if currentBlock == nil { + break + } + + if currentBlock.Type == "CERTIFICATE" { + certBlock = data[start : len(data)-len(remaining)] + certDecodedBlock = currentBlock.Bytes + start += len(certBlock) + } else if strings.HasSuffix(currentBlock.Type, "PRIVATE KEY") { + if c.clientCertPass != nil && x509.IsEncryptedPEMBlock(currentBlock) { + var encoded bytes.Buffer + buf, err := x509.DecryptPEMBlock(currentBlock, []byte(c.clientCertPass())) + if err != nil { + return "", err + } + + pem.Encode(&encoded, &pem.Block{Type: currentBlock.Type, Bytes: buf}) + keyBlock = encoded.Bytes() + start = len(data) - len(remaining) + } else { + keyBlock = data[start : len(data)-len(remaining)] + start += len(keyBlock) + } + } + } + if len(certBlock) == 0 { + return "", fmt.Errorf("failed to find CERTIFICATE") + } + if len(keyBlock) == 0 { + return "", fmt.Errorf("failed to find PRIVATE KEY") + } + + cert, err := tls.X509KeyPair(certBlock, keyBlock) if err != nil { return "", err } @@ -71,15 +122,8 @@ func (c *TLSConfig) AddClientCertFromFile(clientFile string) (string, error) { c.Certificates = append(c.Certificates, cert) // The documentation for the tls.X509KeyPair indicates that the Leaf certificate is not - // retained. Because there isn't any way of creating a tls.Certificate from an x509.Certificate - // short of calling X509KeyPair on the raw bytes, we're forced to parse the certificate over - // again to get the subject name. - certBytes, err := loadCert(data) - if err != nil { - return "", err - } - - crt, err := x509.ParseCertificate(certBytes) + // retained. + crt, err := x509.ParseCertificate(certDecodedBlock) if err != nil { return "", err } diff --git a/core/connection/tlsconfig_clone_17.go b/core/connection/tlsconfig_clone_17.go index c5ff6d4cbf..4b52364fdc 100644 --- a/core/connection/tlsconfig_clone_17.go +++ b/core/connection/tlsconfig_clone_17.go @@ -6,7 +6,7 @@ import "crypto/tls" // used concurrently by a TLS client or server. func (c *TLSConfig) Clone() *TLSConfig { cfg := cloneconfig(c.Config) - return &TLSConfig{cfg} + return &TLSConfig{cfg, c.clientCertPass} } func cloneconfig(c *tls.Config) *tls.Config { diff --git a/core/connstring/connstring.go b/core/connstring/connstring.go index 88e54f3693..aa9524576c 100644 --- a/core/connstring/connstring.go +++ b/core/connstring/connstring.go @@ -32,52 +32,54 @@ func Parse(s string) (ConnString, error) { // ConnString represents a connection string to mongodb. type ConnString struct { - Original string - AppName string - AuthMechanism string - AuthMechanismProperties map[string]string - AuthSource string - Connect ConnectMode - ConnectSet bool - ConnectTimeout time.Duration - ConnectTimeoutSet bool - Database string - HeartbeatInterval time.Duration - HeartbeatIntervalSet bool - Hosts []string - J bool - JSet bool - LocalThreshold time.Duration - LocalThresholdSet bool - MaxConnIdleTime time.Duration - MaxConnIdleTimeSet bool - MaxConnLifeTime time.Duration - MaxConnsPerHost uint16 - MaxConnsPerHostSet bool - MaxIdleConnsPerHost uint16 - MaxIdleConnsPerHostSet bool - Password string - PasswordSet bool - ReadConcernLevel string - ReadPreference string - ReadPreferenceTagSets []map[string]string - ReplicaSet string - ServerSelectionTimeout time.Duration - ServerSelectionTimeoutSet bool - SocketTimeout time.Duration - SocketTimeoutSet bool - SSL bool - SSLSet bool - SSLClientCertificateKeyFile string - SSLClientCertificateKeyFileSet bool - SSLInsecure bool - SSLInsecureSet bool - SSLCaFile string - SSLCaFileSet bool - WString string - WNumber int - WNumberSet bool - Username string + Original string + AppName string + AuthMechanism string + AuthMechanismProperties map[string]string + AuthSource string + Connect ConnectMode + ConnectSet bool + ConnectTimeout time.Duration + ConnectTimeoutSet bool + Database string + HeartbeatInterval time.Duration + HeartbeatIntervalSet bool + Hosts []string + J bool + JSet bool + LocalThreshold time.Duration + LocalThresholdSet bool + MaxConnIdleTime time.Duration + MaxConnIdleTimeSet bool + MaxConnLifeTime time.Duration + MaxConnsPerHost uint16 + MaxConnsPerHostSet bool + MaxIdleConnsPerHost uint16 + MaxIdleConnsPerHostSet bool + Password string + PasswordSet bool + ReadConcernLevel string + ReadPreference string + ReadPreferenceTagSets []map[string]string + ReplicaSet string + ServerSelectionTimeout time.Duration + ServerSelectionTimeoutSet bool + SocketTimeout time.Duration + SocketTimeoutSet bool + SSL bool + SSLSet bool + SSLClientCertificateKeyFile string + SSLClientCertificateKeyFileSet bool + SSLClientCertificateKeyPassword func() string + SSLClientCertificateKeyPasswordSet bool + SSLInsecure bool + SSLInsecureSet bool + SSLCaFile string + SSLCaFileSet bool + WString string + WNumber int + WNumberSet bool + Username string WTimeout time.Duration WTimeoutSet bool @@ -472,6 +474,9 @@ func (p *parser) addOption(pair string) error { p.SSLSet = true p.SSLClientCertificateKeyFile = value p.SSLClientCertificateKeyFileSet = true + case "sslclientcertificatekeypassword": + p.SSLClientCertificateKeyPassword = func() string { return value } + p.SSLClientCertificateKeyPasswordSet = true case "sslinsecure": switch value { case "true": diff --git a/core/topology/topology_options.go b/core/topology/topology_options.go index b66a334880..2361e19d3f 100644 --- a/core/topology/topology_options.go +++ b/core/topology/topology_options.go @@ -101,6 +101,9 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option } if cs.SSLClientCertificateKeyFileSet { + if cs.SSLClientCertificateKeyPasswordSet && cs.SSLClientCertificateKeyPassword != nil { + tlsConfig.SetClientCertDecryptPassword(cs.SSLClientCertificateKeyPassword) + } s, err := tlsConfig.AddClientCertFromFile(cs.SSLClientCertificateKeyFile) if err != nil { return err diff --git a/mongo/client_options.go b/mongo/client_options.go index 5ad39ba46f..16491f28e1 100644 --- a/mongo/client_options.go +++ b/mongo/client_options.go @@ -323,6 +323,19 @@ func (co *ClientOptions) SSLClientCertificateKeyFile(s string) *ClientOptions { return &ClientOptions{next: co, opt: fn} } +// SSLClientCertificateKeyPassword provides a callback that returns a password used for decrypting the +// private key of a PEM file (if one is provided). +func (co *ClientOptions) SSLClientCertificateKeyPassword(s func() string) *ClientOptions { + var fn option = func(c *Client) error { + if !c.connString.SSLClientCertificateKeyPasswordSet { + c.connString.SSLClientCertificateKeyPassword = s + c.connString.SSLClientCertificateKeyPasswordSet = true + } + return nil + } + return &ClientOptions{next: co, opt: fn} +} + // SSLInsecure indicates whether to skip the verification of the server certificate and hostname. func (co *ClientOptions) SSLInsecure(b bool) *ClientOptions { var fn option = func(c *Client) error { diff --git a/mongo/client_options_test.go b/mongo/client_options_test.go index 8d52b0f8e9..dfdfeb6b24 100644 --- a/mongo/client_options_test.go +++ b/mongo/client_options_test.go @@ -86,6 +86,7 @@ func TestClientOptions_chainAll(t *testing.T) { SocketTimeout(2 * time.Second). SSL(true). SSLClientCertificateKeyFile("client.pem"). + SSLClientCertificateKeyPassword(func() string { return "password" }). SSLInsecure(false). SSLCaFile("ca.pem"). WString("majority").