diff --git a/client.go b/client.go index 962c06a39..524e54268 100644 --- a/client.go +++ b/client.go @@ -99,6 +99,9 @@ type Dialer struct { // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar http.CookieJar + + // Custom proxy connect header + ProxyConnectHeader http.Header } // Dial creates a new client connection by calling DialContext with a background context. @@ -274,7 +277,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h return nil, nil, err } if proxyURL != nil { - dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) + dialer, err := proxy_FromURL(proxyURL, &netDialer{d.ProxyConnectHeader, netDial}) if err != nil { return nil, nil, err } diff --git a/client_server_test.go b/client_server_test.go index 7e7636f4e..5c544926a 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -156,6 +156,9 @@ func TestProxyDial(t *testing.T) { cstDialer := cstDialer // make local copy for modification on next line. cstDialer.Proxy = http.ProxyURL(surl) + cstDialer.ProxyConnectHeader = map[string][]string{ + "User-Agents": {"xxx"}, + } connect := false origHandler := s.Server.Config.Handler @@ -166,6 +169,10 @@ func TestProxyDial(t *testing.T) { if r.Method == "CONNECT" { connect = true w.WriteHeader(http.StatusOK) + if r.Header.Get("User-Agents") != "xxx" { + t.Log("xxx not found in the request header") + http.Error(w, "header xxx not found", http.StatusMethodNotAllowed) + } return } diff --git a/proxy.go b/proxy.go index e87a8c9f0..51b7e7229 100644 --- a/proxy.go +++ b/proxy.go @@ -14,21 +14,26 @@ import ( "strings" ) -type netDialerFunc func(network, addr string) (net.Conn, error) +type netDialer struct { + proxyHeader http.Header + f func(network, addr string) (net.Conn, error) +} -func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { - return fn(network, addr) +func (n netDialer) Dial(network, addr string) (net.Conn, error) { + return n.f(network, addr) } func init() { proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { - return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil + p, _ := forwardDialer.(*netDialer) + return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, proxyHeader: p.proxyHeader}, nil }) } type httpProxyDialer struct { proxyURL *url.URL forwardDial func(network, addr string) (net.Conn, error) + proxyHeader http.Header } func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { @@ -47,6 +52,10 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) } } + for k, v := range hpd.proxyHeader { + connectHeader[k] = v + } + connectReq := &http.Request{ Method: "CONNECT", URL: &url.URL{Opaque: addr},