diff --git a/cmd/repeater/repeater.go b/cmd/repeater/repeater.go index 7ba1621..663230c 100644 --- a/cmd/repeater/repeater.go +++ b/cmd/repeater/repeater.go @@ -78,9 +78,16 @@ func main() { TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, + Proxy: http.ProxyFromEnvironment, }, } } + + disableRedirectsEnv := os.Getenv("ESCAPE_REPEATER_DISABLE_REDIRECTS") + if disableRedirectsEnv == "1" || disableRedirectsEnv == "true" { + roundtrip.DisableRedirects = true + } + logger.Info("Starting repeater client...") go logger.AlwaysConnect(url, repeaterId) diff --git a/pkg/roundtrip/roudtrip.go b/pkg/roundtrip/roudtrip.go index b204e55..705ebaa 100644 --- a/pkg/roundtrip/roudtrip.go +++ b/pkg/roundtrip/roudtrip.go @@ -9,6 +9,7 @@ import ( var DefaultClient = &http.Client{} var MTLSClient *http.Client = nil +var DisableRedirects = false const mTLSHeader = "X-Escape-mTLS" @@ -46,6 +47,15 @@ func HandleRequest(protoReq *proto.Request) *proto.Response { tls(protoReq.Url) } client := DefaultClient + + if httpReq.Header.Get("X-Disable-Redirects") == "true" || DisableRedirects { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + } else { + client.CheckRedirect = nil + } + mTLS := false if httpReq.Header.Get(mTLSHeader) != "" { if MTLSClient != nil {