diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index 5f9e4bdd..b8fda231 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -44,8 +44,9 @@ var ( defaultKeyFile = "key.pem" ) -// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid. -func tlsConfig(certFile, keyFile string) (*tls.Config, error) { +// Load the certificate from cert & key files and optional client CA file, +// will use the default certificate if the provided info are invalid. +func tlsConfig(certFile, keyFile, caFile string) (*tls.Config, error) { if certFile == "" || keyFile == "" { certFile, keyFile = defaultCertFile, defaultKeyFile } @@ -54,7 +55,15 @@ func tlsConfig(certFile, keyFile string) (*tls.Config, error) { if err != nil { return nil, err } - return &tls.Config{Certificates: []tls.Certificate{cert}}, nil + + cfg := &tls.Config{Certificates: []tls.Certificate{cert}} + + if pool, _ := loadCA(caFile); pool != nil { + cfg.ClientCAs = pool + cfg.ClientAuth = tls.RequireAndVerifyClientCert + } + + return cfg, nil } func loadCA(caFile string) (cp *x509.CertPool, err error) { diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 4a157b2c..f08f4132 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -67,7 +67,7 @@ func main() { } // NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate. - tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile) + tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile, "") if err != nil { // generate random self-signed certificate. cert, err := gost.GenCertificate() diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 3fc1a9c9..f64056b3 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -128,6 +128,10 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { InsecureSkipVerify: !node.GetBool("secure"), RootCAs: rootCAs, } + if cert, err := tls.LoadX509KeyPair(node.Get("cert"), node.Get("key")); err == nil { + tlsCfg.Certificates = []tls.Certificate{cert} + } + wsOpts := &gost.WSOptions{} wsOpts.EnableCompression = node.GetBool("compression") wsOpts.ReadBufferSize = node.GetInt("rbuf") @@ -343,7 +347,7 @@ func (r *route) GenRouters() ([]router, error) { } } certFile, keyFile := node.Get("cert"), node.Get("key") - tlsCfg, err := tlsConfig(certFile, keyFile) + tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca")) if err != nil && certFile != "" && keyFile != "" { return nil, err }