From 2c0eb8df9ac4a0745b7d1a7a6098a2e518bd2504 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 18 Mar 2020 13:39:29 +0800 Subject: [PATCH] fix issue #520 --- client.go | 8 ++++++++ cmd/gost/route.go | 11 ++++++----- http2.go | 6 +++--- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index f18bcaf8..c840067f 100644 --- a/client.go +++ b/client.go @@ -83,6 +83,7 @@ type Transporter interface { type DialOptions struct { Timeout time.Duration Chain *Chain + Host string } // DialOption allows a common way to set DialOptions. @@ -102,6 +103,13 @@ func ChainDialOption(chain *Chain) DialOption { } } +// HostDialOption specifies the host used by Transporter.Dial +func HostDialOption(host string) DialOption { + return func(opts *DialOptions) { + opts.Host = host + } +} + // HandshakeOptions describes the options for handshake. type HandshakeOptions struct { Addr string diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 979b39ba..3fc1a9c9 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -234,8 +234,14 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { connector = gost.AutoConnector(node.User) } + host := node.Get("host") + if host == "" { + host = node.Host + } + node.DialOptions = append(node.DialOptions, gost.TimeoutDialOption(timeout), + gost.HostDialOption(host), ) node.ConnectOptions = []gost.ConnectOption{ @@ -244,11 +250,6 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { gost.NoDelayConnectOption(node.GetBool("nodelay")), } - host := node.Get("host") - if host == "" { - host = node.Host - } - sshConfig := &gost.SSHConfig{} if s := node.Get("ssh_key"); s != "" { key, err := gost.ParseSSHKeyFile(s) diff --git a/http2.go b/http2.go index ce4435f0..1c3cf5a9 100644 --- a/http2.go +++ b/http2.go @@ -234,7 +234,7 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err transport := http2.Transport{ TLSClientConfig: tr.tlsConfig, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) { conn, err := opts.Chain.Dial(addr) if err != nil { return nil, err @@ -256,13 +256,13 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err pr, pw := io.Pipe() req := &http.Request{ Method: http.MethodConnect, - URL: &url.URL{Scheme: "https", Host: addr}, + URL: &url.URL{Scheme: "https", Host: opts.Host}, Header: make(http.Header), Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, Body: pr, - Host: addr, + Host: opts.Host, ContentLength: -1, } if tr.path != "" {