Skip to content

Commit

Permalink
support mutual TLS authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
ginuerzh committed May 23, 2020
1 parent 60d7e01 commit b285bcd
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
15 changes: 12 additions & 3 deletions cmd/gost/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ var (
defaultKeyFile = "key.pem"
)

// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid.
func tlsConfig(certFile, keyFile string) (*tls.Config, error) {
// Load the certificate from cert & key files and optional client CA file,
// will use the default certificate if the provided info are invalid.
func tlsConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
if certFile == "" || keyFile == "" {
certFile, keyFile = defaultCertFile, defaultKeyFile
}
Expand All @@ -54,7 +55,15 @@ func tlsConfig(certFile, keyFile string) (*tls.Config, error) {
if err != nil {
return nil, err
}
return &tls.Config{Certificates: []tls.Certificate{cert}}, nil

cfg := &tls.Config{Certificates: []tls.Certificate{cert}}

if pool, _ := loadCA(caFile); pool != nil {
cfg.ClientCAs = pool
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}

return cfg, nil
}

func loadCA(caFile string) (cp *x509.CertPool, err error) {
Expand Down
2 changes: 1 addition & 1 deletion cmd/gost/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func main() {
}

// NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate.
tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile)
tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile, "")
if err != nil {
// generate random self-signed certificate.
cert, err := gost.GenCertificate()
Expand Down
6 changes: 5 additions & 1 deletion cmd/gost/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
InsecureSkipVerify: !node.GetBool("secure"),
RootCAs: rootCAs,
}
if cert, err := tls.LoadX509KeyPair(node.Get("cert"), node.Get("key")); err == nil {
tlsCfg.Certificates = []tls.Certificate{cert}
}

wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = node.GetBool("compression")
wsOpts.ReadBufferSize = node.GetInt("rbuf")
Expand Down Expand Up @@ -343,7 +347,7 @@ func (r *route) GenRouters() ([]router, error) {
}
}
certFile, keyFile := node.Get("cert"), node.Get("key")
tlsCfg, err := tlsConfig(certFile, keyFile)
tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca"))
if err != nil && certFile != "" && keyFile != "" {
return nil, err
}
Expand Down

0 comments on commit b285bcd

Please sign in to comment.