From 57e8760e075846721285f38620c4ac285e0ee295 Mon Sep 17 00:00:00 2001 From: Tim Carter <112617649+tecarter94@users.noreply.github.com> Date: Thu, 19 Dec 2024 15:11:00 +1100 Subject: [PATCH 1/2] Create go.yml --- .github/workflows/go.yml | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 .github/workflows/go.yml diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..1cedc7f --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,31 @@ +# This workflow will build a golang project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go + +name: Go + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.23.4' + + - name: Install dependencies + run: go mod tidy + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -v ./... From 997c62c05aaede1c6194013cb0c9e06995d8268b Mon Sep 17 00:00:00 2001 From: Tim Carter Date: Thu, 19 Dec 2024 16:40:25 +1100 Subject: [PATCH 2/2] Port domain proxy. --- .github/workflows/go.yml | 9 +- .gitignore | 17 +- cmd/client/main.go | 19 ++ cmd/server/main.go | 19 ++ deploy/Dockerfile.all-in-one | 15 ++ go.mod | 10 + go.sum | 8 + pkg/client/client.go | 110 +++++++++ pkg/common/common.go | 215 ++++++++++++++++++ pkg/server/server.go | 420 +++++++++++++++++++++++++++++++++++ test/domainproxy_test.go | 392 ++++++++++++++++++++++++++++++++ test/testdata/bar-1.0.pom | 9 + 12 files changed, 1236 insertions(+), 7 deletions(-) create mode 100644 cmd/client/main.go create mode 100644 cmd/server/main.go create mode 100644 deploy/Dockerfile.all-in-one create mode 100644 go.mod create mode 100644 go.sum create mode 100644 pkg/client/client.go create mode 100644 pkg/common/common.go create mode 100644 pkg/server/server.go create mode 100644 test/domainproxy_test.go create mode 100644 test/testdata/bar-1.0.pom diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 1cedc7f..1d11d0d 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -24,8 +24,11 @@ jobs: - name: Install dependencies run: go mod tidy - - name: Build - run: go build -v ./... + - name: Build Domain Proxy Server + run: go build -v -o bin/domainproxyserver cmd/server/main.go + + - name: Build Domain Proxy Client + run: go build -v -o bin/domainproxyclient cmd/client/main.go - name: Test - run: go test -v ./... + run: go test -v ./test diff --git a/.gitignore b/.gitignore index 6f72f89..919eab9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,3 @@ -# If you prefer the allow list template instead of the deny list, see community template: -# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore -# # Binaries for programs and plugins *.exe *.exe~ @@ -15,7 +12,7 @@ *.out # Dependency directories (remove the comment below to include it) -# vendor/ +/vendor/ # Go workspace file go.work @@ -23,3 +20,15 @@ go.work.sum # env file .env + +# IntelliJ files +.idea/ +*.ipr +*.iml +*.iws + +# VS Code files +.vscode/ + +# Binary files +/bin/ \ No newline at end of file diff --git a/cmd/client/main.go b/cmd/client/main.go new file mode 100644 index 0000000..815b27d --- /dev/null +++ b/cmd/client/main.go @@ -0,0 +1,19 @@ +package main + +import ( + . "org.jboss.pnc.domain-proxy/pkg/client" + "os" + "os/signal" + "syscall" +) + +func main() { + domainProxyClient := NewDomainProxyClient() + ready := make(chan bool) + domainProxyClient.Start(ready) + <-ready + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + <-signals + domainProxyClient.Stop() +} diff --git a/cmd/server/main.go b/cmd/server/main.go new file mode 100644 index 0000000..02faa18 --- /dev/null +++ b/cmd/server/main.go @@ -0,0 +1,19 @@ +package main + +import ( + . "org.jboss.pnc.domain-proxy/pkg/server" + "os" + "os/signal" + "syscall" +) + +func main() { + domainProxyServer := NewDomainProxyServer() + ready := make(chan bool) + domainProxyServer.Start(ready) + <-ready + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + <-signals + domainProxyServer.Stop() +} diff --git a/deploy/Dockerfile.all-in-one b/deploy/Dockerfile.all-in-one new file mode 100644 index 0000000..9d0d53c --- /dev/null +++ b/deploy/Dockerfile.all-in-one @@ -0,0 +1,15 @@ +FROM registry.access.redhat.com/ubi9/go-toolset:1.22.5-1731639025@sha256:45170b6e45114849b5d2c0e55d730ffa4a709ddf5f58b9e810548097b085e78f as builder +USER 0 +WORKDIR /work +COPY ./ . + +RUN go mod tidy +RUN go build -o domainproxyserver cmd/server/main.go +RUN go build -o domainproxyclient cmd/client/main.go + +FROM quay.io/konflux-ci/buildah-task:latest@sha256:5cbd487022fb7ac476cbfdea25513b810f7e343ec48f89dc6a4e8c3c39fa37a2 +USER 0 +WORKDIR /work/ + +COPY --from=builder /work/domainproxyserver /app/domain-proxy-server +COPY --from=builder /work/domainproxyclient /app/domain-proxy-client diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8fb042a --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module org.jboss.pnc.domain-proxy + +go 1.23.4 + +require github.com/elazarl/goproxy v0.0.0-20241218172127-ac55c7698e0d + +require ( + golang.org/x/net v0.32.0 // indirect + golang.org/x/text v0.21.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..99248bc --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +github.com/elazarl/goproxy v0.0.0-20241218172127-ac55c7698e0d h1:r8DboPPvhhSMCWfmBEDoLuNvHetXH8/AZUdaRLNYgXE= +github.com/elazarl/goproxy v0.0.0-20241218172127-ac55c7698e0d/go.mod h1:3TKt+OFpElWuCtt5bphUyO97JT606j9Ffx4S2pfIcCo= +github.com/elazarl/goproxy/ext v0.0.0-20241217120900-7711dfa3811c h1:R+i10jtNSzKJKqEZAYJnR9M8y14k0zrNHqD1xkv/A2M= +github.com/elazarl/goproxy/ext v0.0.0-20241217120900-7711dfa3811c/go.mod h1:gNh8nYJoAm43RfaxurUnxr+N1PwuFV3ZMl/efxlIlY8= +golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= +golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= diff --git a/pkg/client/client.go b/pkg/client/client.go new file mode 100644 index 0000000..9453ff1 --- /dev/null +++ b/pkg/client/client.go @@ -0,0 +1,110 @@ +package client + +import ( + "fmt" + "net" + . "org.jboss.pnc.domain-proxy/pkg/common" + "time" +) + +const ( + Localhost = "localhost" + HttpPortKey = "DOMAIN_PROXY_HTTP_PORT" + DefaultHttpPort = 8080 + HttpToDomainSocket = "HTTP <-> Domain Socket" +) + +var logger = NewLogger("Domain Proxy Client") +var common = NewCommon(logger) + +type DomainProxyClient struct { + sharedParams *SharedParams + httpPort int +} + +func NewDomainProxyClient() *DomainProxyClient { + return &DomainProxyClient{ + sharedParams: common.NewSharedParams(), + httpPort: getHttpPort(), + } +} + +func (dpc *DomainProxyClient) Start(ready chan<- bool) { + sharedParams := dpc.sharedParams + logger.Println("Starting domain proxy client...") + var err error + sharedParams.Listener, err = net.Listen(TCP, fmt.Sprintf("%s:%d", Localhost, dpc.httpPort)) + if err != nil { + logger.Fatalf("Failed to start HTTP server: %v", err) + } + go dpc.startClient(ready) +} + +func (dpc *DomainProxyClient) startClient(ready chan<- bool) { + sharedParams := dpc.sharedParams + logger.Printf("HTTP server listening on port %d", dpc.httpPort) + ready <- true + for { + select { + case <-sharedParams.RunningContext.Done(): + return + default: + if serverConnection, err := sharedParams.Listener.Accept(); err != nil { + select { + case <-sharedParams.RunningContext.Done(): + return + default: + logger.Printf("Failed to accept server connection: %v", err) + } + } else { + go dpc.handleConnectionRequest(serverConnection) + } + } + } +} + +func (dpc *DomainProxyClient) handleConnectionRequest(serverConnection net.Conn) { + sharedParams := dpc.sharedParams + if err := serverConnection.SetDeadline(time.Now().Add(sharedParams.IdleTimeout)); err != nil { + common.HandleSetDeadlineError(serverConnection, err) + return + } + connectionNo := sharedParams.HttpConnectionCounter.Add(1) + logger.Printf("Handling %s Connection %d", HttpToDomainSocket, connectionNo) + startTime := time.Now() + domainConnection, err := net.DialTimeout(UNIX, sharedParams.DomainSocket, sharedParams.ConnectionTimeout) + if err != nil { + logger.Printf("Failed to connect to domain socket: %v", err) + if err = serverConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + if err := domainConnection.SetDeadline(time.Now().Add(sharedParams.IdleTimeout)); err != nil { + common.HandleSetDeadlineError(domainConnection, err) + if err = serverConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + // Initiate transfer between server and domain + go func() { + common.BiDirectionalTransfer(sharedParams.RunningContext, serverConnection, domainConnection, sharedParams.ByteBufferSize, HttpToDomainSocket, connectionNo) + logger.Printf("%s Connection %d ended after %d ms", HttpToDomainSocket, connectionNo, time.Since(startTime).Milliseconds()) + }() +} + +func (dpc *DomainProxyClient) Stop() { + sharedParams := dpc.sharedParams + logger.Println("Shutting down domain proxy client...") + sharedParams.InitiateShutdown() + if sharedParams.Listener != nil { + if err := sharedParams.Listener.Close(); err != nil { + common.HandleListenerCloseError(err) + } + } +} + +func getHttpPort() int { + return common.GetIntEnvVariable(HttpPortKey, DefaultHttpPort) +} diff --git a/pkg/common/common.go b/pkg/common/common.go new file mode 100644 index 0000000..8f6a380 --- /dev/null +++ b/pkg/common/common.go @@ -0,0 +1,215 @@ +package common + +import ( + "context" + "errors" + "io" + "log" + "net" + "os" + "strconv" + "strings" + "sync/atomic" + "time" +) + +const ( + ByteBufferSizeKey = "DOMAIN_PROXY_BYTE_BUFFER_SIZE" + DefaultByteBufferSize = 32768 + DomainSocketKey = "DOMAIN_PROXY_DOMAIN_SOCKET" + DefaultDomainSocket = "/tmp/domain-socket.sock" + ConnectionTimeoutKey = "DOMAIN_PROXY_CONNECTION_TIMEOUT" + DefaultConnectionTimeout = 10000 * time.Millisecond + IdleTimeoutKey = "DOMAIN_PROXY_IDLE_TIMEOUT" + DefaultIdleTimeout = 30000 * time.Millisecond + TCP = "tcp" + UNIX = "unix" +) + +type Common struct { + logger *log.Logger +} + +type SharedParams struct { + ByteBufferSize int + DomainSocket string + ConnectionTimeout time.Duration + IdleTimeout time.Duration + HttpConnectionCounter atomic.Uint64 + Listener net.Listener + RunningContext context.Context + InitiateShutdown context.CancelFunc +} + +func NewCommon(logger *log.Logger) *Common { + return &Common{ + logger: logger, + } +} + +func NewLogger(appName string) *log.Logger { + return log.New(os.Stdout, appName+" ", log.LstdFlags|log.Lshortfile) +} + +func (c *Common) NewSharedParams() *SharedParams { + runningContext, initiateShutdown := context.WithCancel(context.Background()) + return &SharedParams{ + ByteBufferSize: c.getByteBufferSize(), + DomainSocket: c.getDomainSocket(), + ConnectionTimeout: c.getConnectionTimeout(), + IdleTimeout: c.getIdleTimeout(), + RunningContext: runningContext, + InitiateShutdown: initiateShutdown, + } +} + +func (c *Common) BiDirectionalTransfer(runningContext context.Context, leftConnection, rightConnection net.Conn, byteBufferSize int, connectionType string, connectionNo uint64) { + defer c.CloseConnection(leftConnection, rightConnection, connectionType, connectionNo) + transferContext, terminateTransfer := context.WithCancel(runningContext) + go c.Transfer(transferContext, terminateTransfer, leftConnection, rightConnection, byteBufferSize, connectionType, connectionNo) + go c.Transfer(transferContext, terminateTransfer, rightConnection, leftConnection, byteBufferSize, connectionType, connectionNo) + <-transferContext.Done() +} + +func (c *Common) Transfer(transferContext context.Context, terminateTransfer context.CancelFunc, sourceConnection, targetConnection net.Conn, bufferSize int, connectionType string, connectionNo uint64) { + defer terminateTransfer() + buf := make([]byte, bufferSize) + for { + select { + case <-transferContext.Done(): + return + default: + if n, err := io.CopyBuffer(sourceConnection, targetConnection, buf); err != nil { + c.handleConnectionError(err, connectionType, connectionNo) + return + } else if n > 0 { + c.logger.Printf("%d bytes transferred for %s connection %d", n, connectionType, connectionNo) + } else { + // Nothing more to transfer + return + } + } + } +} + +func (c *Common) HandleSetDeadlineError(connection net.Conn, err error) { + c.logger.Printf("Failed to set deadline: %v", err) + if err = connection.Close(); err != nil { + c.HandleConnectionCloseError(err) + } +} + +func (c *Common) HandleConnectionCloseError(err error) { + c.logger.Printf("Failed to close connection: %v", err) +} + +func (c *Common) HandleListenerCloseError(err error) { + c.logger.Printf("Failed to close listener: %v", err) +} + +func (c *Common) handleConnectionError(err error, connectionType string, connectionNo uint64) { + var netErr net.Error + if !errors.Is(err, net.ErrClosed) { // We don't care if connection has been closed, because this is expected + if errors.As(err, &netErr) && netErr.Timeout() { + c.logger.Printf("%s connection %d timed out", connectionType, connectionNo) + } else if err != io.EOF { + c.logger.Printf("Failed to transfer data using %s connection %d: %v", connectionType, connectionNo, err) + } + } +} + +func (c *Common) CloseConnection(leftConnection, rightConnection net.Conn, connectionType string, connectionNo uint64) { + if err := leftConnection.Close(); err != nil { + c.HandleConnectionCloseError(err) + } + if err := rightConnection.Close(); err != nil { + c.HandleConnectionCloseError(err) + } + c.logger.Printf("%s connection %d closed", connectionType, connectionNo) +} + +func (c *Common) GetEnvVariable(key, defaultValue string) string { + value := os.Getenv(key) + if value == "" { + c.logger.Printf("Environment variable %s is not set, using default value: %s", key, defaultValue) + return defaultValue + } + return value +} + +func (c *Common) GetIntEnvVariable(key string, defaultValue int) int { + valueStr := os.Getenv(key) + if valueStr == "" { + c.logger.Printf("Environment variable %s is not set, using default value: %d", key, defaultValue) + return defaultValue + } + value, err := strconv.Atoi(valueStr) + if err != nil { + c.logger.Printf("Invalid environment variable %s: %v, using default value: %d", key, err, defaultValue) + return defaultValue + } + return value +} + +func (c *Common) GetCsvEnvVariable(key, defaultValue string) map[string]bool { + valuesStr := os.Getenv(key) + if valuesStr == "" { + c.logger.Printf("Environment variable %s is not set, using default value: %s", key, defaultValue) + return c.parseCsvToMap(defaultValue) + } + return c.parseCsvToMap(valuesStr) +} + +func (c *Common) GetMillisecondsEnvVariable(key string, defaultValue time.Duration) time.Duration { + valueStr := os.Getenv(key) + if valueStr == "" { + c.logger.Printf("Environment variable %s is not set, using default value: %d", key, defaultValue.Milliseconds()) + return defaultValue + } + value, err := strconv.Atoi(valueStr) + if err != nil { + c.logger.Printf("Invalid environment variable %s: %v, using default value: %d", key, err, defaultValue.Milliseconds()) + return defaultValue + } + return time.Duration(value) * time.Millisecond +} + +func (c *Common) parseCsvToMap(csvString string) map[string]bool { + valuesStr := strings.Split(csvString, ",") + values := make(map[string]bool) + for _, value := range valuesStr { + trimmedValue := strings.TrimSpace(value) + values[trimmedValue] = true + } + return values +} + +func (c *Common) GetBoolEnvVariable(key string, defaultValue bool) bool { + valueStr := os.Getenv(key) + if valueStr == "" { + c.logger.Printf("Environment variable %s is not set, using default value: %t", key, defaultValue) + return defaultValue + } + value, err := strconv.ParseBool(valueStr) + if err != nil { + c.logger.Printf("Invalid environment variable %s: %v, using default value: %t", key, err, defaultValue) + return defaultValue + } + return value +} + +func (c *Common) getByteBufferSize() int { + return c.GetIntEnvVariable(ByteBufferSizeKey, DefaultByteBufferSize) +} + +func (c *Common) getDomainSocket() string { + return c.GetEnvVariable(DomainSocketKey, DefaultDomainSocket) +} + +func (c *Common) getConnectionTimeout() time.Duration { + return c.GetMillisecondsEnvVariable(ConnectionTimeoutKey, DefaultConnectionTimeout) +} + +func (c *Common) getIdleTimeout() time.Duration { + return c.GetMillisecondsEnvVariable(IdleTimeoutKey, DefaultIdleTimeout) +} diff --git a/pkg/server/server.go b/pkg/server/server.go new file mode 100644 index 0000000..2f4bb14 --- /dev/null +++ b/pkg/server/server.go @@ -0,0 +1,420 @@ +package server + +import ( + "bufio" + "encoding/base64" + "fmt" + "net" + "net/http" + . "org.jboss.pnc.domain-proxy/pkg/common" + "os" + "strconv" + "strings" + "sync/atomic" + "time" +) + +const ( + HttpPort = 80 + HttpsPort = 443 + TargetWhitelistKey = "DOMAIN_PROXY_TARGET_WHITELIST" + DefaultTargetWhitelist = "localhost,repo.maven.apache.org,repository.jboss.org,packages.confluent.io,jitpack.io,repo.gradle.org,plugins.gradle.org" + EnableInternalProxyKey = "DOMAIN_PROXY_ENABLE_INTERNAL_PROXY" + DefaultEnableInternalProxy = false + InternalProxyHostKey = "DOMAIN_PROXY_INTERNAL_PROXY_HOST" + DefaultInternalProxyHost = "indy-generic-proxy" + InternalProxyPortKey = "DOMAIN_PROXY_INTERNAL_PROXY_PORT" + DefaultInternalProxyPort = 80 + InternalProxyUserKey = "DOMAIN_PROXY_INTERNAL_PROXY_USER" + DefaultInternalProxyUser = "" + InternalProxyPasswordKey = "DOMAIN_PROXY_INTERNAL_PROXY_PASSWORD" + DefaultInternalProxyPassword = "" + InternalNonProxyHostsKey = "DOMAIN_PROXY_INTERNAL_NON_PROXY_HOSTS" + DefaultInternalNonProxyHosts = "localhost" + DomainSocketToHttp = "Domain Socket <-> HTTP" + DomainSocketToHttps = "Domain Socket <-> HTTPS" +) + +var logger = NewLogger("Domain Proxy Server") +var common = NewCommon(logger) + +type DomainProxyServer struct { + sharedParams *SharedParams + targetWhitelist map[string]bool + enableInternalProxy bool + internalProxyHost string + internalProxyPort int + internalProxyUser string + internalProxyPassword string + internalNonProxyHosts map[string]bool + httpsConnectionCounter atomic.Uint64 +} + +func NewDomainProxyServer() *DomainProxyServer { + return &DomainProxyServer{ + sharedParams: common.NewSharedParams(), + targetWhitelist: getTargetWhitelist(), + enableInternalProxy: getEnableInternalProxy(), + internalProxyHost: getInternalProxyHost(), + internalProxyPort: getInternalProxyPort(), + internalProxyUser: getInternalProxyUser(), + internalProxyPassword: getInternalProxyPassword(), + internalNonProxyHosts: getInternalNonProxyHosts(), + } +} + +func (dps *DomainProxyServer) Start(ready chan<- bool) { + sharedParams := dps.sharedParams + logger.Println("Starting domain proxy server...") + if _, err := os.Stat(sharedParams.DomainSocket); err == nil { + if err := os.Remove(sharedParams.DomainSocket); err != nil { + logger.Fatalf("Failed to delete existing domain socket: %v", err) + } + } + var err error + sharedParams.Listener, err = net.Listen(UNIX, sharedParams.DomainSocket) + if err != nil { + logger.Fatalf("Failed to start domain socket listener: %v", err) + } + go dps.startServer(ready) +} + +func (dps *DomainProxyServer) startServer(ready chan<- bool) { + sharedParams := dps.sharedParams + logger.Printf("Domain socket server listening on %s", sharedParams.DomainSocket) + ready <- true + for { + select { + case <-sharedParams.RunningContext.Done(): + return + default: + if domainConnection, err := sharedParams.Listener.Accept(); err != nil { + select { + case <-sharedParams.RunningContext.Done(): + return + default: + logger.Printf("Failed to accept domain socket connection: %v", err) + } + } else { + go dps.handleConnectionRequest(domainConnection) + } + } + } +} + +func (dps *DomainProxyServer) handleConnectionRequest(domainConnection net.Conn) { + sharedParams := dps.sharedParams + if err := domainConnection.SetDeadline(time.Now().Add(sharedParams.IdleTimeout)); err != nil { + common.HandleSetDeadlineError(domainConnection, err) + return + } + reader := bufio.NewReader(domainConnection) + request, err := http.ReadRequest(reader) + if err != nil { + logger.Printf("Failed to read request: %v", err) + if err = domainConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + writer := &responseWriter{connection: domainConnection} + if request.Method == http.MethodConnect { + dps.handleHttpsConnection(domainConnection, writer, request) + } else { + dps.handleHttpConnection(domainConnection, writer, request) + } +} + +func (dps *DomainProxyServer) handleHttpConnection(sourceConnection net.Conn, writer http.ResponseWriter, request *http.Request) { + sharedParams := dps.sharedParams + connectionNo := sharedParams.HttpConnectionCounter.Add(1) + targetHost, targetPort := getTargetHostAndPort(request.Host, HttpPort) + actualTargetHost, actualTargetPort := targetHost, targetPort + targetConnectionName := "target" + useInternalProxy := dps.useInternalProxy(targetHost) + // Redirect connection to internal proxy if enabled + if useInternalProxy { + targetHost, targetPort = dps.internalProxyHost, dps.internalProxyPort + logger.Printf("Handling %s Connection %d with internal proxy %s:%d and target %s:%d", DomainSocketToHttp, connectionNo, targetHost, targetPort, actualTargetHost, actualTargetPort) + targetConnectionName = "internal proxy" + } else { + logger.Printf("Handling %s Connection %d with target %s:%d", DomainSocketToHttp, connectionNo, actualTargetHost, actualTargetPort) + } + // Check if target is whitelisted + if !dps.isTargetWhitelisted(actualTargetHost, writer) { + if err := sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + startTime := time.Now() + request.Header.Del("Proxy-Connection") // Prevent keep-alive as it breaks internal proxy authentication + request.Header.Set("Connection", "close") // Prevent keep-alive as it breaks internal proxy authentication + // Update request with target details for internal proxy if enabled + if useInternalProxy { + request.Header.Set("Host", fmt.Sprintf("%s:%d", actualTargetHost, actualTargetPort)) + // Add authentication details if configured + if dps.internalProxyUser != "" && dps.internalProxyPassword != "" { + request.Header.Set("Proxy-Authorization", "Basic "+GetBasicAuth(dps.internalProxyUser, dps.internalProxyPassword)) + } + } + // Try to connect to target or internal proxy + targetConnection, err := net.DialTimeout(TCP, fmt.Sprintf("%s:%d", targetHost, targetPort), sharedParams.ConnectionTimeout) + if err != nil { + dps.handleErrorResponse(writer, err, fmt.Sprintf("Failed to connect to %s", targetConnectionName)) + if err = sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + if err = targetConnection.SetDeadline(time.Now().Add(sharedParams.IdleTimeout)); err != nil { + common.HandleSetDeadlineError(targetConnection, err) + if err = sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + // Send HTTP request to internal proxy if enabled + if useInternalProxy { + err = request.WriteProxy(targetConnection) + } else { + err = request.Write(targetConnection) + } + if err != nil { + dps.handleErrorResponse(writer, err, fmt.Sprintf("Failed to send request to %s", targetConnectionName)) + if err = targetConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + if err = sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + // Initiate transfer between source and target or internal proxy + go func() { + common.BiDirectionalTransfer(sharedParams.RunningContext, sourceConnection, targetConnection, sharedParams.ByteBufferSize, DomainSocketToHttp, connectionNo) + logger.Printf("%s Connection %d ended after %d ms", DomainSocketToHttp, connectionNo, time.Since(startTime).Milliseconds()) + }() +} + +func (dps *DomainProxyServer) handleHttpsConnection(sourceConnection net.Conn, writer http.ResponseWriter, request *http.Request) { + sharedParams := dps.sharedParams + connectionNo := dps.httpsConnectionCounter.Add(1) + targetHost, targetPort := getTargetHostAndPort(request.Host, HttpsPort) + actualTargetHost, actualTargetPort := targetHost, targetPort + targetConnectionName := "target" + useInternalProxy := dps.useInternalProxy(targetHost) + // Redirect connection to internal proxy if enabled + if useInternalProxy { + targetHost, targetPort = dps.internalProxyHost, dps.internalProxyPort + logger.Printf("Handling %s Connection %d with internal proxy %s:%d and target %s:%d", DomainSocketToHttps, connectionNo, targetHost, targetPort, actualTargetHost, actualTargetPort) + targetConnectionName = "internal proxy" + } else { + logger.Printf("Handling %s Connection %d with target %s:%d", DomainSocketToHttps, connectionNo, actualTargetHost, actualTargetPort) + } + // Check if target is whitelisted + if !dps.isTargetWhitelisted(actualTargetHost, writer) { + if err := sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + startTime := time.Now() + request.Header.Del("Proxy-Connection") // Prevent keep-alive as it breaks internal proxy authentication + request.Header.Set("Connection", "close") // Prevent keep-alive as it breaks internal proxy authentication + // Try to connect to target or internal proxy + targetConnection, err := net.DialTimeout(TCP, fmt.Sprintf("%s:%d", targetHost, targetPort), sharedParams.ConnectionTimeout) + if err != nil { + dps.handleErrorResponse(writer, err, fmt.Sprintf("Failed to connect to %s", targetConnectionName)) + if err = sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + if err = targetConnection.SetDeadline(time.Now().Add(sharedParams.IdleTimeout)); err != nil { + common.HandleSetDeadlineError(targetConnection, err) + if err = sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + // Create HTTPS connection to internal proxy if enabled + if useInternalProxy { + proxyConnectRequest := fmt.Sprintf("CONNECT %s:%d HTTP/1.1\r\nHost: %s:%d\r\nConnection: close\r\n", actualTargetHost, actualTargetPort, actualTargetHost, actualTargetPort) // Prevent keep-alive as it breaks internal proxy authentication + // Add authentication details if configured + if dps.internalProxyUser != "" && dps.internalProxyPassword != "" { + proxyConnectRequest += fmt.Sprintf("Proxy-Authorization: Basic %s\r\n", GetBasicAuth(dps.internalProxyUser, dps.internalProxyPassword)) + } + proxyConnectRequest += "\r\n" + if _, err = targetConnection.Write([]byte(proxyConnectRequest)); err != nil { + dps.handleErrorResponse(writer, err, "Failed to send connect request to internal proxy") + if err = targetConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + if err = sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + proxyReader := bufio.NewReader(targetConnection) + proxyResponse, err := http.ReadResponse(proxyReader, request) + if err != nil { + dps.handleErrorResponse(writer, err, "Failed to establish connection with internal proxy") + if err = targetConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + if err = sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } else if proxyResponse.StatusCode != http.StatusOK { + proxyResponse.Header.Set("Connection", "close") // Prevent keep-alive as it breaks internal proxy authentication + if err := proxyResponse.Write(sourceConnection); err != nil { + dps.handleErrorResponse(writer, err, "Failed to send internal proxy response to source") + } + if err = targetConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + if err = sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + } + // Notify source that HTTPS connection has been established to target or internal proxy + if _, err = writer.Write([]byte("HTTP/1.1 200 Connection Established\r\nConnection: close\r\n\r\n")); err != nil { // Prevent keep-alive as it breaks internal proxy authentication + dps.handleErrorResponse(writer, err, "Failed to send connect response to source") + if err = targetConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + if err = sourceConnection.Close(); err != nil { + common.HandleConnectionCloseError(err) + } + return + } + // Initiate transfer between source and target or internal proxy + go func() { + common.BiDirectionalTransfer(sharedParams.RunningContext, sourceConnection, targetConnection, sharedParams.ByteBufferSize, DomainSocketToHttps, connectionNo) + logger.Printf("%s Connection %d ended after %d ms", DomainSocketToHttps, connectionNo, time.Since(startTime).Milliseconds()) + }() +} + +func getTargetHostAndPort(host string, defaultPort int) (string, int) { + hostAndPort := strings.Split(host, ":") + targetHost := hostAndPort[0] + targetPort := defaultPort + if len(hostAndPort) > 1 { + if port, err := strconv.Atoi(hostAndPort[1]); err == nil { + targetPort = port + } + } + return targetHost, targetPort +} + +func (dps *DomainProxyServer) isTargetWhitelisted(targetHost string, writer http.ResponseWriter) bool { + if !dps.targetWhitelist[targetHost] { + message := fmt.Sprintf("Target host %s is not whitelisted", targetHost) + logger.Println(message) + http.Error(writer, message, http.StatusForbidden) + return false + } + return true +} + +func (dps *DomainProxyServer) useInternalProxy(targetHost string) bool { + if dps.enableInternalProxy { + if !dps.internalNonProxyHosts[targetHost] { + return true + } else { + logger.Printf("Target host %s is non-proxy host", targetHost) + } + } + return false +} + +func GetBasicAuth(user string, password string) string { + return base64.StdEncoding.EncodeToString([]byte(user + ":" + password)) +} + +func (dps *DomainProxyServer) handleErrorResponse(writer http.ResponseWriter, err error, message string) { + logger.Printf("%s: %v", message, err) + writer.Header().Set("Connection", "close") // Prevent keep-alive as it breaks internal proxy authentication + status := http.StatusInternalServerError + http.Error(writer, message+": "+err.Error(), status) +} + +func (dps *DomainProxyServer) Stop() { + sharedParams := dps.sharedParams + logger.Println("Shutting down domain proxy server...") + sharedParams.InitiateShutdown() + if sharedParams.Listener != nil { + if err := sharedParams.Listener.Close(); err != nil { + common.HandleListenerCloseError(err) + } + } + if _, err := os.Stat(sharedParams.DomainSocket); err == nil { + if err := os.Remove(sharedParams.DomainSocket); err != nil { + logger.Printf("Failed to delete domain socket: %v", err) + } + } +} + +type responseWriter struct { + connection net.Conn + header http.Header + statusCode int +} + +func (rw *responseWriter) Header() http.Header { + if rw.header == nil { + rw.header = make(http.Header) + } + return rw.header +} + +func (rw *responseWriter) Write(data []byte) (int, error) { + return rw.connection.Write(data) +} + +func (rw *responseWriter) WriteHeader(statusCode int) { + rw.statusCode = statusCode + headers := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) + headers += "Connection: close\r\n" // Prevent keep-alive as it breaks internal proxy authentication + for k, v := range rw.Header() { + for _, vv := range v { + headers += fmt.Sprintf("%s: %s\r\n", k, vv) + } + } + headers += "\r\n" + if _, err := rw.connection.Write([]byte(headers)); err != nil { + logger.Printf("Failed to write headers to connection: %v", err) + } +} + +func getTargetWhitelist() map[string]bool { + return common.GetCsvEnvVariable(TargetWhitelistKey, DefaultTargetWhitelist) +} + +func getEnableInternalProxy() bool { + return common.GetBoolEnvVariable(EnableInternalProxyKey, DefaultEnableInternalProxy) +} + +func getInternalProxyHost() string { + return common.GetEnvVariable(InternalProxyHostKey, DefaultInternalProxyHost) +} + +func getInternalProxyPort() int { + return common.GetIntEnvVariable(InternalProxyPortKey, DefaultInternalProxyPort) +} + +func getInternalProxyUser() string { + return common.GetEnvVariable(InternalProxyUserKey, DefaultInternalProxyUser) +} + +func getInternalProxyPassword() string { + return common.GetEnvVariable(InternalProxyPasswordKey, DefaultInternalProxyPassword) +} + +func getInternalNonProxyHosts() map[string]bool { + return common.GetCsvEnvVariable(InternalNonProxyHostsKey, DefaultInternalNonProxyHosts) +} diff --git a/test/domainproxy_test.go b/test/domainproxy_test.go new file mode 100644 index 0000000..ede8348 --- /dev/null +++ b/test/domainproxy_test.go @@ -0,0 +1,392 @@ +package integration + +import ( + "crypto/md5" + "crypto/tls" + "encoding/hex" + "errors" + "fmt" + "github.com/elazarl/goproxy" + "io" + "math/rand" + "net/http" + "net/http/httptest" + "net/url" + . "org.jboss.pnc.domain-proxy/pkg/client" + . "org.jboss.pnc.domain-proxy/pkg/common" + . "org.jboss.pnc.domain-proxy/pkg/server" + "os" + "strconv" + "strings" + "testing" +) + +const ( + DomainProxyPort = "8081" + InternalProxyPort = "8082" + DomainProxyUrl = "http://" + Localhost + ":" + DomainProxyPort + ContentType = "text/xml" + Md5Hash = "ea3ca57f8f99d1d210d1b438c9841440" + ContentLength = "403" + MockUrlPath = "/com/foo/bar/1.0/bar-1.0.pom" + NonExistentUrlPath = "/com/foo/bar/1.0/bar-2.0.pom" + NonWhitelistedUrl = "repo1.maven.org/maven2/org/apache/maven/plugins/maven-jar-plugin/3.4.1/maven-jar-plugin-3.4.1.jar" + NonExistentHost = "foo.bar" + User = "foo" + Password = "bar" +) + +func createClient(t *testing.T) *http.Client { + proxyUrl, err := url.Parse(DomainProxyUrl) + if err != nil { + t.Fatal(err) + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyUrl), + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + return &http.Client{ + Transport: transport, + } +} + +func getMd5Hash(bytes []byte) string { + hash := md5.Sum(bytes) + return hex.EncodeToString(hash[:]) +} + +func getRandomDomainSocket() string { + return "/tmp/domain-socket-" + strconv.Itoa(rand.Int()) + ".sock" +} + +func mockHandler(t *testing.T) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && r.URL.Path == MockUrlPath { + // Mock GET response + pom, err := os.ReadFile("testdata/bar-1.0.pom") + if err != nil { + t.Fatal(err) + } + w.Header().Set("Content-Type", ContentType) + w.WriteHeader(http.StatusOK) + if _, err := w.Write(pom); err != nil { + t.Fatal(err) + } + } else if r.Method == http.MethodHead && r.URL.Path == MockUrlPath { + // Mock HEAD response + w.Header().Set("Content-Type", ContentType) + w.Header().Set("Content-Length", ContentLength) + w.WriteHeader(http.StatusOK) + } else { + http.NotFound(w, r) + } + } +} + +func startDomainProxy() (*DomainProxyServer, *DomainProxyClient) { + domainProxyServer := NewDomainProxyServer() + serverReady := make(chan bool) + go domainProxyServer.Start(serverReady) + <-serverReady + clientReady := make(chan bool) + domainProxyClient := NewDomainProxyClient() + go domainProxyClient.Start(clientReady) + <-clientReady + return domainProxyServer, domainProxyClient +} + +func stopDomainProxy(domainProxyServer *DomainProxyServer, domainProxyClient *DomainProxyClient) { + domainProxyServer.Stop() + domainProxyClient.Stop() +} + +func startMockServers(t *testing.T) (*httptest.Server, *httptest.Server) { + mockHandler := mockHandler(t) + mockHttpServer := httptest.NewServer(mockHandler) + mockHttpsServer := httptest.NewUnstartedServer(mockHandler) + mockHttpsServer.StartTLS() + return mockHttpServer, mockHttpsServer +} + +func stopMockServers(mockHttpServer *httptest.Server, mockHttpsServer *httptest.Server) { + mockHttpServer.Close() + mockHttpsServer.Close() +} + +func startInternalProxyServer(t *testing.T, onRequestFunction func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response), onConnectFunction func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string)) *http.Server { + internalProxy := goproxy.NewProxyHttpServer() + internalProxy.Verbose = true + if onRequestFunction != nil { + internalProxy.OnRequest().DoFunc(onRequestFunction) + internalProxy.OnRequest().HandleConnectFunc(onConnectFunction) + } + internalProxyServer := &http.Server{ + Addr: Localhost + ":" + InternalProxyPort, + Handler: internalProxy, + } + go func() { + if err := internalProxyServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Error(err) + } + }() + return internalProxyServer +} + +func stopInternalProxyServer(t *testing.T, internalProxyServer *http.Server) { + err := internalProxyServer.Close() + if err != nil { + t.Fatal(err) + } +} + +func commonTestBehaviour(t *testing.T, qualifier string) { + // Set env variables + t.Setenv(DomainSocketKey, getRandomDomainSocket()) + t.Setenv(HttpPortKey, DomainProxyPort) + t.Setenv(TargetWhitelistKey, "127.0.0.1,foo.bar") + // Start services + domainProxyServer, domainProxyClient := startDomainProxy() + defer stopDomainProxy(domainProxyServer, domainProxyClient) + // Start mock HTTP and HTTPS servers + mockHttpServer, mockHttpsServer := startMockServers(t) + defer stopMockServers(mockHttpServer, mockHttpsServer) + mockHttpUrl := mockHttpServer.URL + mockHttpsUrl := mockHttpsServer.URL + // Create HTTP client + httpClient := createClient(t) + + t.Run(fmt.Sprintf("Test HTTP GET dependency%s", qualifier), func(t *testing.T) { + response, err := httpClient.Get(mockHttpUrl + MockUrlPath) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + t.Fatalf("Actual HTTP status %d did not match expected HTTP status %d", response.StatusCode, http.StatusOK) + } + pom, err := io.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + hash := getMd5Hash(pom) + if hash != Md5Hash { + t.Fatalf("Actual MD5 hash %s did not match expected MD5 hash %s", hash, Md5Hash) + } + }) + + t.Run(fmt.Sprintf("Test HTTPS GET dependency%s", qualifier), func(t *testing.T) { + response, err := httpClient.Get(mockHttpsUrl + MockUrlPath) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + t.Fatalf("Actual HTTP status %d did not match expected HTTP status %d", response.StatusCode, http.StatusOK) + } + pom, err := io.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + hash := getMd5Hash(pom) + if hash != Md5Hash { + t.Fatalf("Actual MD5 hash %s did not match expected MD5 hash %s", hash, Md5Hash) + } + }) + + t.Run(fmt.Sprintf("Test HTTP GET non-existent dependency%s", qualifier), func(t *testing.T) { + response, err := httpClient.Get(mockHttpUrl + NonExistentUrlPath) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusNotFound { + t.Fatalf("Actual HTTP status %d did not match expected HTTP status %d", response.StatusCode, http.StatusNotFound) + } + }) + + t.Run(fmt.Sprintf("Test HTTPS GET non-existent dependency%s", qualifier), func(t *testing.T) { + response, err := httpClient.Get(mockHttpsUrl + NonExistentUrlPath) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusNotFound { + t.Fatalf("Actual HTTP status %d did not match expected HTTP status %d", response.StatusCode, http.StatusNotFound) + } + }) + + t.Run(fmt.Sprintf("Test HTTP non-whitelisted host%s", qualifier), func(t *testing.T) { + response, err := httpClient.Get("http://" + NonWhitelistedUrl) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusForbidden { + t.Fatalf("Actual HTTP status %d did not match expected HTTP status %d", response.StatusCode, http.StatusForbidden) + } + }) + + t.Run(fmt.Sprintf("Test HTTPS non-whitelisted host%s", qualifier), func(t *testing.T) { + _, err := httpClient.Get("https://" + NonWhitelistedUrl) + statusText := http.StatusText(http.StatusForbidden) + if !strings.Contains(err.Error(), statusText) { + t.Fatalf("Actual error %s did not contain expected HTTP status text %s", err.Error(), statusText) + } + }) + + t.Run(fmt.Sprintf("Test HTTP non-existent host%s", qualifier), func(t *testing.T) { + response, err := httpClient.Get("http://" + NonExistentHost) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusInternalServerError { + t.Fatalf("Actual HTTP status %d did not match expected HTTP status %d", response.StatusCode, http.StatusInternalServerError) + } + }) + + t.Run(fmt.Sprintf("Test HTTPS non-existent host%s", qualifier), func(t *testing.T) { + _, err := httpClient.Get("https://" + NonExistentHost) + internalServerStatusText := http.StatusText(http.StatusInternalServerError) + badGatewayStatusText := http.StatusText(http.StatusBadGateway) + if !strings.Contains(err.Error(), internalServerStatusText) && !strings.Contains(err.Error(), badGatewayStatusText) { // Internal proxy may return 502 Bad Gateway + t.Fatalf("Actual error %s did not contain expected HTTP status text %s or %s", err.Error(), internalServerStatusText, badGatewayStatusText) + } + }) + + t.Run(fmt.Sprintf("Test HTTP HEAD dependency%s", qualifier), func(t *testing.T) { + response, err := httpClient.Head(mockHttpUrl + MockUrlPath) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + actualContentLength := response.Header.Get("Content-Length") + if actualContentLength != ContentLength { + t.Fatalf("Actual content length %s did not match expected content length %s", actualContentLength, ContentLength) + } + }) + + t.Run(fmt.Sprintf("Test HTTPS HEAD dependency%s", qualifier), func(t *testing.T) { + response, err := httpClient.Head(mockHttpsUrl + MockUrlPath) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + actualContentLength := response.Header.Get("Content-Length") + if actualContentLength != ContentLength { + t.Fatalf("Actual content length %s did not match expected content length %s", actualContentLength, ContentLength) + } + }) + + t.Run(fmt.Sprintf("Test HTTP HEAD non-existent dependency%s", qualifier), func(t *testing.T) { + response, err := httpClient.Head(mockHttpUrl + NonExistentUrlPath) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusNotFound { + t.Fatalf("Actual HTTP status %d did not match expected HTTP status %d", response.StatusCode, http.StatusNotFound) + } + }) + + t.Run(fmt.Sprintf("Test HTTPS HEAD non-existent dependency%s", qualifier), func(t *testing.T) { + response, err := httpClient.Head(mockHttpsUrl + NonExistentUrlPath) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusNotFound { + t.Fatalf("Actual HTTP status %d did not match expected HTTP status %d", response.StatusCode, http.StatusNotFound) + } + }) +} + +func commonInternalProxyTestBehaviour(t *testing.T, qualifier string, onRequestFunction func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response), onConnectFunction func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string)) { + // Start internal proxy + internalProxyServer := startInternalProxyServer(t, onRequestFunction, onConnectFunction) + // Set env variables + t.Setenv(EnableInternalProxyKey, "true") + t.Setenv(InternalProxyHostKey, Localhost) + t.Setenv(InternalProxyPortKey, InternalProxyPort) + t.Setenv(InternalNonProxyHostsKey, "example.com") + // Run tests with internal proxy + commonTestBehaviour(t, qualifier) + // Stop internal proxy + stopInternalProxyServer(t, internalProxyServer) + // Set non-proxy hosts env variable + t.Setenv(InternalNonProxyHostsKey, "127.0.0.1,foo.bar") + // Run tests without internal proxy + commonTestBehaviour(t, qualifier+" and non-proxy host") +} + +func TestDomainProxy(t *testing.T) { + commonTestBehaviour(t, "") +} + +func TestDomainProxyWithInternalProxy(t *testing.T) { + commonInternalProxyTestBehaviour(t, " with internal proxy", nil, nil) +} + +func TestDomainProxyWithInternalProxyAndAuthentication(t *testing.T) { + // Set env variables + t.Setenv(InternalProxyUserKey, User) + t.Setenv(InternalProxyPasswordKey, Password) + basicAuth := "Basic " + GetBasicAuth(User, Password) + // Create internal proxy HTTP authentication handler + onRequestFunction := func(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { + if req.Header.Get("Proxy-Authorization") != basicAuth { + return nil, goproxy.NewResponse(req, goproxy.ContentTypeText, http.StatusProxyAuthRequired, http.StatusText(http.StatusProxyAuthRequired)) + } + return req, nil + } + // Create internal proxy HTTPS authentication handler + onConnectionFunction := func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { + req := ctx.Req + authHeader := req.Header.Get("Proxy-Authorization") + if authHeader != basicAuth { + ctx.Resp = goproxy.NewResponse(req, goproxy.ContentTypeText, http.StatusProxyAuthRequired, http.StatusText(http.StatusProxyAuthRequired)) + return goproxy.RejectConnect, host + } + return goproxy.OkConnect, host + } + // Run tests with internal proxy and authentication + commonInternalProxyTestBehaviour(t, " with internal proxy and authentication", onRequestFunction, onConnectionFunction) + + // Set invalid authentication env variables + t.Setenv(DomainSocketKey, getRandomDomainSocket()) + t.Setenv(InternalProxyUserKey, "123") + t.Setenv(InternalProxyPasswordKey, "456") + t.Setenv(InternalNonProxyHostsKey, "example.com") + // Start internal proxy + internalProxyServer := startInternalProxyServer(t, onRequestFunction, onConnectionFunction) + defer stopInternalProxyServer(t, internalProxyServer) + // Start services + domainProxyServer, domainProxyClient := startDomainProxy() + defer stopDomainProxy(domainProxyServer, domainProxyClient) + // Start mock HTTP and HTTPS servers + mockHttpServer, mockHttpsServer := startMockServers(t) + defer stopMockServers(mockHttpServer, mockHttpsServer) + mockHttpUrl := mockHttpServer.URL + mockHttpsUrl := mockHttpsServer.URL + // Create HTTP client + httpClient := createClient(t) + + t.Run("Test HTTP GET dependency with internal proxy and invalid authentication", func(t *testing.T) { + response, err := httpClient.Get(mockHttpUrl + MockUrlPath) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusProxyAuthRequired { + t.Fatalf("Actual HTTP status %d did not match expected HTTP status %d", response.StatusCode, http.StatusProxyAuthRequired) + } + }) + + t.Run("Test HTTPS GET dependency with internal proxy and invalid authentication", func(t *testing.T) { + _, err := httpClient.Get(mockHttpsUrl + MockUrlPath) + statusText := http.StatusText(http.StatusProxyAuthRequired) + if !strings.Contains(err.Error(), statusText) { + t.Fatalf("Actual error %s did not contain expected HTTP status text %s", err.Error(), statusText) + } + }) +} diff --git a/test/testdata/bar-1.0.pom b/test/testdata/bar-1.0.pom new file mode 100644 index 0000000..45f65a9 --- /dev/null +++ b/test/testdata/bar-1.0.pom @@ -0,0 +1,9 @@ + + + 4.0.0 + com.foo + bar + 1.0 +