Skip to content

Commit

Permalink
Merge branch 'jtv4k-master'
Browse files Browse the repository at this point in the history
Change-Id: I84b0b443c8eec74a4ade17fd48804934876c841a
  • Loading branch information
skriptble committed Apr 10, 2018
2 parents 915670b + b9b5bc1 commit 5f9341c
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 58 deletions.
66 changes: 55 additions & 11 deletions core/connection/tlsconfig.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package connection

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
Expand All @@ -9,10 +10,14 @@ import (
"errors"
"fmt"
"io/ioutil"
"strings"
)

// TLSConfig contains options for configuring a TLS connection to the server.
type TLSConfig struct{ *tls.Config }
type TLSConfig struct {
*tls.Config
clientCertPass func() string
}

// NewTLSConfig creates a new TLSConfig.
func NewTLSConfig() *TLSConfig {
Expand All @@ -22,6 +27,13 @@ func NewTLSConfig() *TLSConfig {
return cfg
}

// SetClientCertDecryptPassword sets a function to retrieve the decryption password
// necessary to read a certificate. This is a function instead of a string to
// provide greater flexibility when deciding how to retrieve and store the password.
func (c *TLSConfig) SetClientCertDecryptPassword(f func() string) {
c.clientCertPass = f
}

// SetInsecure sets whether the client should verify the server's certificate
// chain and hostnames.
func (c *TLSConfig) SetInsecure(allow bool) {
Expand Down Expand Up @@ -63,23 +75,55 @@ func (c *TLSConfig) AddClientCertFromFile(clientFile string) (string, error) {
return "", err
}

cert, err := tls.X509KeyPair(data, data)
var currentBlock *pem.Block
var certBlock, certDecodedBlock, keyBlock []byte

remaining := data
start := 0
for {
currentBlock, remaining = pem.Decode(remaining)
if currentBlock == nil {
break
}

if currentBlock.Type == "CERTIFICATE" {
certBlock = data[start : len(data)-len(remaining)]
certDecodedBlock = currentBlock.Bytes
start += len(certBlock)
} else if strings.HasSuffix(currentBlock.Type, "PRIVATE KEY") {
if c.clientCertPass != nil && x509.IsEncryptedPEMBlock(currentBlock) {
var encoded bytes.Buffer
buf, err := x509.DecryptPEMBlock(currentBlock, []byte(c.clientCertPass()))
if err != nil {
return "", err
}

pem.Encode(&encoded, &pem.Block{Type: currentBlock.Type, Bytes: buf})
keyBlock = encoded.Bytes()
start = len(data) - len(remaining)
} else {
keyBlock = data[start : len(data)-len(remaining)]
start += len(keyBlock)
}
}
}
if len(certBlock) == 0 {
return "", fmt.Errorf("failed to find CERTIFICATE")
}
if len(keyBlock) == 0 {
return "", fmt.Errorf("failed to find PRIVATE KEY")
}

cert, err := tls.X509KeyPair(certBlock, keyBlock)
if err != nil {
return "", err
}

c.Certificates = append(c.Certificates, cert)

// The documentation for the tls.X509KeyPair indicates that the Leaf certificate is not
// retained. Because there isn't any way of creating a tls.Certificate from an x509.Certificate
// short of calling X509KeyPair on the raw bytes, we're forced to parse the certificate over
// again to get the subject name.
certBytes, err := loadCert(data)
if err != nil {
return "", err
}

crt, err := x509.ParseCertificate(certBytes)
// retained.
crt, err := x509.ParseCertificate(certDecodedBlock)
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion core/connection/tlsconfig_clone_17.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import "crypto/tls"
// used concurrently by a TLS client or server.
func (c *TLSConfig) Clone() *TLSConfig {
cfg := cloneconfig(c.Config)
return &TLSConfig{cfg}
return &TLSConfig{cfg, c.clientCertPass}
}

func cloneconfig(c *tls.Config) *tls.Config {
Expand Down
97 changes: 51 additions & 46 deletions core/connstring/connstring.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,52 +32,54 @@ func Parse(s string) (ConnString, error) {

// ConnString represents a connection string to mongodb.
type ConnString struct {
Original string
AppName string
AuthMechanism string
AuthMechanismProperties map[string]string
AuthSource string
Connect ConnectMode
ConnectSet bool
ConnectTimeout time.Duration
ConnectTimeoutSet bool
Database string
HeartbeatInterval time.Duration
HeartbeatIntervalSet bool
Hosts []string
J bool
JSet bool
LocalThreshold time.Duration
LocalThresholdSet bool
MaxConnIdleTime time.Duration
MaxConnIdleTimeSet bool
MaxConnLifeTime time.Duration
MaxConnsPerHost uint16
MaxConnsPerHostSet bool
MaxIdleConnsPerHost uint16
MaxIdleConnsPerHostSet bool
Password string
PasswordSet bool
ReadConcernLevel string
ReadPreference string
ReadPreferenceTagSets []map[string]string
ReplicaSet string
ServerSelectionTimeout time.Duration
ServerSelectionTimeoutSet bool
SocketTimeout time.Duration
SocketTimeoutSet bool
SSL bool
SSLSet bool
SSLClientCertificateKeyFile string
SSLClientCertificateKeyFileSet bool
SSLInsecure bool
SSLInsecureSet bool
SSLCaFile string
SSLCaFileSet bool
WString string
WNumber int
WNumberSet bool
Username string
Original string
AppName string
AuthMechanism string
AuthMechanismProperties map[string]string
AuthSource string
Connect ConnectMode
ConnectSet bool
ConnectTimeout time.Duration
ConnectTimeoutSet bool
Database string
HeartbeatInterval time.Duration
HeartbeatIntervalSet bool
Hosts []string
J bool
JSet bool
LocalThreshold time.Duration
LocalThresholdSet bool
MaxConnIdleTime time.Duration
MaxConnIdleTimeSet bool
MaxConnLifeTime time.Duration
MaxConnsPerHost uint16
MaxConnsPerHostSet bool
MaxIdleConnsPerHost uint16
MaxIdleConnsPerHostSet bool
Password string
PasswordSet bool
ReadConcernLevel string
ReadPreference string
ReadPreferenceTagSets []map[string]string
ReplicaSet string
ServerSelectionTimeout time.Duration
ServerSelectionTimeoutSet bool
SocketTimeout time.Duration
SocketTimeoutSet bool
SSL bool
SSLSet bool
SSLClientCertificateKeyFile string
SSLClientCertificateKeyFileSet bool
SSLClientCertificateKeyPassword func() string
SSLClientCertificateKeyPasswordSet bool
SSLInsecure bool
SSLInsecureSet bool
SSLCaFile string
SSLCaFileSet bool
WString string
WNumber int
WNumberSet bool
Username string

WTimeout time.Duration
WTimeoutSet bool
Expand Down Expand Up @@ -472,6 +474,9 @@ func (p *parser) addOption(pair string) error {
p.SSLSet = true
p.SSLClientCertificateKeyFile = value
p.SSLClientCertificateKeyFileSet = true
case "sslclientcertificatekeypassword":
p.SSLClientCertificateKeyPassword = func() string { return value }
p.SSLClientCertificateKeyPasswordSet = true
case "sslinsecure":
switch value {
case "true":
Expand Down
3 changes: 3 additions & 0 deletions core/topology/topology_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option
}

if cs.SSLClientCertificateKeyFileSet {
if cs.SSLClientCertificateKeyPasswordSet && cs.SSLClientCertificateKeyPassword != nil {
tlsConfig.SetClientCertDecryptPassword(cs.SSLClientCertificateKeyPassword)
}
s, err := tlsConfig.AddClientCertFromFile(cs.SSLClientCertificateKeyFile)
if err != nil {
return err
Expand Down
13 changes: 13 additions & 0 deletions mongo/client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,19 @@ func (co *ClientOptions) SSLClientCertificateKeyFile(s string) *ClientOptions {
return &ClientOptions{next: co, opt: fn}
}

// SSLClientCertificateKeyPassword provides a callback that returns a password used for decrypting the
// private key of a PEM file (if one is provided).
func (co *ClientOptions) SSLClientCertificateKeyPassword(s func() string) *ClientOptions {
var fn option = func(c *Client) error {
if !c.connString.SSLClientCertificateKeyPasswordSet {
c.connString.SSLClientCertificateKeyPassword = s
c.connString.SSLClientCertificateKeyPasswordSet = true
}
return nil
}
return &ClientOptions{next: co, opt: fn}
}

// SSLInsecure indicates whether to skip the verification of the server certificate and hostname.
func (co *ClientOptions) SSLInsecure(b bool) *ClientOptions {
var fn option = func(c *Client) error {
Expand Down
1 change: 1 addition & 0 deletions mongo/client_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ func TestClientOptions_chainAll(t *testing.T) {
SocketTimeout(2 * time.Second).
SSL(true).
SSLClientCertificateKeyFile("client.pem").
SSLClientCertificateKeyPassword(func() string { return "password" }).
SSLInsecure(false).
SSLCaFile("ca.pem").
WString("majority").
Expand Down

0 comments on commit 5f9341c

Please sign in to comment.