diff --git a/README.md b/README.md index ba191cf..c87bcbe 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,10 @@ func main() { ### Modifying requests and responses -You can modify requests and responses by setting `OnRequest` and `OnResponse` handlers. +You can modify requests and responses using `OnRequest` and `OnResponse` handlers. + +The example below will block requests to `example.net` and add a short comment to +the end of every HTML response. ```go proxy := gomitmproxy.NewProxy(gomitmproxy.Config{ @@ -89,7 +92,39 @@ proxy := gomitmproxy.NewProxy(gomitmproxy.Config{ log.Printf("onResponse: was blocked") } - return nil + res := session.Response() + req := session.Request() + + if strings.Index(res.Header.Get("Content-Type"), "text/html") != 0 { + // Do nothing with non-HTML responses + return nil + } + + b, err := proxyutil.ReadDecompressedBody(res) + // Close the original body + _ = res.Body.Close() + if err != nil { + return proxyutil.NewErrorResponse(req, err) + } + + // Use latin1 before modifying the body + // Using this 1-byte encoding will let us preserve all original characters + // regardless of what exactly is the encoding + body, err := proxyutil.DecodeLatin1(bytes.NewReader(b)) + if err != nil { + return proxyutil.NewErrorResponse(session.Request(), err) + } + + // Modifying the original body + modifiedBody, err := proxyutil.EncodeLatin1(body + "<!-- EDITED -->") + if err != nil { + return proxyutil.NewErrorResponse(session.Request(), err) + } + + res.Body = ioutil.NopCloser(bytes.NewReader(modifiedBody)) + res.Header.Del("Content-Encoding") + res.ContentLength = int64(len(modifiedBody)) + return res }, }) ``` @@ -231,6 +266,8 @@ mitmConfig, err := mitm.NewConfig(x509c, privateKey, &CustomCertsStorage{ * [X] Support HTTP CONNECT over TLS * [X] Test plain HTTP requests inside HTTP CONNECT * [X] Test memory leaks + * [X] Editing response body in a callback + * [X] Handle unknown content-encoding values * [ ] Unit tests * [ ] Check & fix TODOs * [ ] MITM diff 
--git a/auth.go b/auth.go index 39e2aca..c6994fb 100644 --- a/auth.go +++ b/auth.go @@ -4,6 +4,8 @@ import ( "encoding/base64" "net/http" "strings" + + "github.com/AdguardTeam/gomitmproxy/proxyutil" ) // See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt @@ -18,7 +20,7 @@ func basicAuth(username, password string) string { // newNotAuthorizedResponse creates a new "407 (Proxy Authentication Required)" response func newNotAuthorizedResponse(session *Session) *http.Response { - res := NewResponse(http.StatusProxyAuthRequired, nil, session.req) + res := proxyutil.NewResponse(http.StatusProxyAuthRequired, nil, session.req) // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Proxy-Authenticate res.Header.Set("Proxy-Authenticate", "Basic") diff --git a/examples/mitm/main.go b/examples/mitm/main.go index 667412c..837ac1e 100644 --- a/examples/mitm/main.go +++ b/examples/mitm/main.go @@ -1,9 +1,11 @@ package main import ( + "bytes" "crypto/rsa" "crypto/tls" "crypto/x509" + "io/ioutil" "net" "net/http" "os" @@ -12,6 +14,8 @@ import ( "syscall" "time" + "github.com/AdguardTeam/gomitmproxy/proxyutil" + "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/gomitmproxy" "github.com/AdguardTeam/gomitmproxy/mitm" @@ -49,13 +53,13 @@ func main() { mitmConfig.SetOrganization("gomitmproxy") // cert organization // GENERATE A CERT FOR HTTP OVER TLS PROXY - //proxyCert, err := mitmConfig.GetOrCreateCert("127.0.0.1") - //if err != nil { - // panic(err) - //} - //tlsConfig := &tls.Config{ - // Certificates: []tls.Certificate{*proxyCert}, - //} + proxyCert, err := mitmConfig.GetOrCreateCert("127.0.0.1") + if err != nil { + panic(err) + } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{*proxyCert}, + } // PREPARE PROXY addr := &net.TCPAddr{ @@ -65,7 +69,7 @@ func main() { proxy := gomitmproxy.NewProxy(gomitmproxy.Config{ ListenAddr: addr, - // TLSConfig: tlsConfig, + TLSConfig: tlsConfig, Username: "user", Password: "pass", @@ -98,7 +102,7 @@ func 
onRequest(session *gomitmproxy.Session) (*http.Request, *http.Response) { if req.URL.Host == "example.net" { body := strings.NewReader("<html><body><h1>Replaced response</h1></body></html>") - res := gomitmproxy.NewResponse(http.StatusOK, body, req) + res := proxyutil.NewResponse(http.StatusOK, body, req) res.Header.Set("Content-Type", "text/html") session.SetProp("blocked", true) return nil, res @@ -112,9 +116,42 @@ func onResponse(session *gomitmproxy.Session) *http.Response { if _, ok := session.GetProp("blocked"); ok { log.Printf("onResponse: was blocked") + return nil + } + + res := session.Response() + req := session.Request() + + if strings.Index(res.Header.Get("Content-Type"), "text/html") != 0 { + // Do nothing with non-HTML responses + return nil + } + + b, err := proxyutil.ReadDecompressedBody(res) + // Close the original body + _ = res.Body.Close() + if err != nil { + return proxyutil.NewErrorResponse(req, err) + } + + // Use latin1 before modifying the body + // Using this 1-byte encoding will let us preserve all original characters + // regardless of what exactly is the encoding + body, err := proxyutil.DecodeLatin1(bytes.NewReader(b)) + if err != nil { + return proxyutil.NewErrorResponse(session.Request(), err) + } + + // Modifying the original body + modifiedBody, err := proxyutil.EncodeLatin1(body + "<!-- EDITED -->") + if err != nil { + return proxyutil.NewErrorResponse(session.Request(), err) } - return nil + res.Body = ioutil.NopCloser(bytes.NewReader(modifiedBody)) + res.Header.Del("Content-Encoding") + res.ContentLength = int64(len(modifiedBody)) + return res } // CustomCertsStorage - an example of a custom cert storage diff --git a/go.mod b/go.mod index d2f9812..73594d6 100644 --- a/go.mod +++ b/go.mod @@ -7,4 +7,5 @@ require ( github.com/pkg/errors v0.8.1 github.com/prometheus/common v0.7.0 github.com/stretchr/testify v1.4.0 + golang.org/x/text v0.3.0 ) diff --git a/go.sum b/go.sum index 2345b3a..8a57706 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,7 @@ golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= diff --git a/helper.go b/helper.go index 53de3fe..110e6e6 100644 --- a/helper.go +++ b/helper.go @@ -1,14 +1,9 @@ package gomitmproxy import ( - "bytes" "errors" - "fmt" "io" - "io/ioutil" "net" - "net/http" - "time" ) var errShutdown = errors.New("proxy is shutting down") @@ -29,57 +24,6 @@ func isCloseable(err error) bool { return false } -// NewResponse builds a new HTTP response. -// If body is nil, an empty byte.Buffer will be provided to be consistent with -// the guarantees provided by http.Transport and http.Client. 
-func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { - if body == nil { - body = &bytes.Buffer{} - } - - rc, ok := body.(io.ReadCloser) - if !ok { - rc = ioutil.NopCloser(body) - } - - res := &http.Response{ - StatusCode: code, - Status: fmt.Sprintf("%d %s", code, http.StatusText(code)), - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: http.Header{}, - Body: rc, - Request: req, - } - - if req != nil { - res.Close = req.Close - res.Proto = req.Proto - res.ProtoMajor = req.ProtoMajor - res.ProtoMinor = req.ProtoMinor - } - - return res -} - -// newErrorResponse creates a new HTTP response with status code 502 Bad Gateway -// "Warning" header is populated with the error details -// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Warning -func newErrorResponse(req *http.Request, err error) *http.Response { - res := NewResponse(http.StatusBadGateway, nil, req) - res.Close = true - - date := res.Header.Get("Date") - if date == "" { - date = time.Now().Format(http.TimeFormat) - } - - w := fmt.Sprintf(`199 "gomitmproxy" %q %q`, err.Error(), date) - res.Header.Add("Warning", w) - return res -} - // A peekedConn subverts the net.Conn.Read implementation, primarily so that // sniffed bytes can be transparently prepended. 
type peekedConn struct { diff --git a/proxy.go b/proxy.go index 43bc44e..a9a78ea 100644 --- a/proxy.go +++ b/proxy.go @@ -13,6 +13,8 @@ import ( "sync" "time" + "github.com/AdguardTeam/gomitmproxy/proxyutil" + "github.com/AdguardTeam/golibs/log" "github.com/pkg/errors" ) @@ -263,7 +265,7 @@ func (p *Proxy) handleRequest(ctx *Context) error { p.raiseOnError(session, err) // res body is closed below (see session.res.body.Close()) // nolint:bodyclose - res = newErrorResponse(session.req, err) + res = proxyutil.NewErrorResponse(session.req, err) if strings.Contains(err.Error(), "x509: ") || strings.Contains(err.Error(), errClientCertRequested.Error()) { @@ -308,7 +310,7 @@ func (p *Proxy) handleAPIRequest(session *Session) error { // nolint:bodyclose // body is actually closed - session.res = NewResponse(http.StatusOK, bytes.NewReader(b), session.req) + session.res = proxyutil.NewResponse(http.StatusOK, bytes.NewReader(b), session.req) defer session.res.Body.Close() session.res.Close = true session.res.Header.Set("Content-Type", "application/x-x509-ca-cert") @@ -318,7 +320,7 @@ func (p *Proxy) handleAPIRequest(session *Session) error { // nolint:bodyclose // body is actually closed - session.res = newErrorResponse(session.req, errors.Errorf("wrong API method")) + session.res = proxyutil.NewErrorResponse(session.req, errors.Errorf("wrong API method")) defer session.res.Body.Close() session.res.Close = true return p.writeResponse(session) @@ -360,7 +362,7 @@ func (p *Proxy) handleTunnel(session *Session) error { p.raiseOnError(session, err) // nolint:bodyclose // body is actually closed - session.res = newErrorResponse(session.req, err) + session.res = proxyutil.NewErrorResponse(session.req, err) _ = p.writeResponse(session) session.res.Body.Close() return err @@ -424,7 +426,7 @@ func (p *Proxy) handleConnect(session *Session) error { p.raiseOnError(session, err) // nolint:bodyclose // body is actually closed - session.res = newErrorResponse(session.req, err) + 
session.res = proxyutil.NewErrorResponse(session.req, err) _ = p.writeResponse(session) session.res.Body.Close() return err @@ -434,7 +436,7 @@ func (p *Proxy) handleConnect(session *Session) error { log.Debug("id=%s: attempting MITM for connection", session.ID()) // nolint:bodyclose // body is actually closed - session.res = NewResponse(http.StatusOK, nil, session.req) + session.res = proxyutil.NewResponse(http.StatusOK, nil, session.req) err = p.writeResponse(session) session.res.Body.Close() if err != nil { @@ -484,7 +486,7 @@ func (p *Proxy) handleConnect(session *Session) error { // nolint:bodyclose // body is actually closed - session.res = NewResponse(http.StatusOK, nil, session.req) + session.res = proxyutil.NewResponse(http.StatusOK, nil, session.req) defer session.res.Body.Close() session.res.ContentLength = -1 @@ -562,7 +564,8 @@ func (p *Proxy) writeResponse(session *Session) error { if p.OnResponse != nil { res := p.OnResponse(session) if res != nil { - defer res.Body.Close() + origBody := res.Body + defer origBody.Close() log.Debug("id=%s: response was overridden by: %s", session.ID(), res.Status) session.res = res } @@ -596,6 +599,11 @@ func (p *Proxy) prepareRequest(req *http.Request, session *Session) { req.URL.Scheme = "https" } req.RemoteAddr = session.ctx.conn.RemoteAddr().String() + + // remove unsupported encodings + if req.Header.Get("Accept-Encoding") != "" { + req.Header.Set("Accept-Encoding", "gzip") + } } // raiseOnError calls p.OnResponse diff --git a/proxyutil/util.go b/proxyutil/util.go new file mode 100644 index 0000000..9409a6d --- /dev/null +++ b/proxyutil/util.go @@ -0,0 +1,99 @@ +// Package proxyutil contains different utility methods that will +// be helpful to gomitmproxy users +package proxyutil + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "net/http" + "time" + + "golang.org/x/text/encoding/charmap" + "golang.org/x/text/transform" +) + +// NewResponse builds a new HTTP response. 
+// If body is nil, an empty byte.Buffer will be provided to be consistent with +// the guarantees provided by http.Transport and http.Client. +func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { + if body == nil { + body = &bytes.Buffer{} + } + + rc, ok := body.(io.ReadCloser) + if !ok { + rc = ioutil.NopCloser(body) + } + + res := &http.Response{ + StatusCode: code, + Status: fmt.Sprintf("%d %s", code, http.StatusText(code)), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + Body: rc, + Request: req, + } + + if req != nil { + res.Close = req.Close + res.Proto = req.Proto + res.ProtoMajor = req.ProtoMajor + res.ProtoMinor = req.ProtoMinor + } + + return res +} + +// NewErrorResponse creates a new HTTP response with status code 502 Bad Gateway +// "Warning" header is populated with the error details +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Warning +func NewErrorResponse(req *http.Request, err error) *http.Response { + res := NewResponse(http.StatusBadGateway, nil, req) + res.Close = true + + date := res.Header.Get("Date") + if date == "" { + date = time.Now().Format(http.TimeFormat) + } + + w := fmt.Sprintf(`199 "gomitmproxy" %q %q`, err.Error(), date) + res.Header.Add("Warning", w) + return res +} + +// ReadDecompressedBody reads full response body and decompresses it if necessary +func ReadDecompressedBody(res *http.Response) ([]byte, error) { + rBody := res.Body + if res.Header.Get("Content-Encoding") == "gzip" { + gzReader, err := gzip.NewReader(rBody) + if err != nil { + return nil, err + } + rBody = gzReader + defer gzReader.Close() + } + return ioutil.ReadAll(rBody) +} + +// DecodeLatin1 - decodes Latin1 string from the reader +// This method is useful for editing response bodies when you don't want +// to handle different encodings +func DecodeLatin1(reader io.Reader) (string, error) { + r := transform.NewReader(reader, charmap.ISO8859_1.NewDecoder()) + b, err := 
ioutil.ReadAll(r) + if err != nil { + return "", err + } + + return string(b), nil +} + +// EncodeLatin1 - encodes the string as a byte array using Latin1 +func EncodeLatin1(str string) ([]byte, error) { + return charmap.ISO8859_1.NewEncoder().Bytes([]byte(str)) +}