From 15002b522938f5253e2cd424fb83b5e0263f92dc Mon Sep 17 00:00:00 2001 From: Suraj Shirvankar Date: Mon, 5 Aug 2024 19:26:26 +0200 Subject: [PATCH] refactor: unify http client creation Signed-off-by: Suraj Shirvankar --- sztp-agent/cmd/daemon.go | 3 +- sztp-agent/cmd/disable.go | 3 +- sztp-agent/cmd/enable.go | 3 +- sztp-agent/cmd/run.go | 3 +- sztp-agent/cmd/status.go | 3 +- sztp-agent/pkg/secureagent/agent.go | 39 +++++++++++++++++++-- sztp-agent/pkg/secureagent/agent_test.go | 5 ++- sztp-agent/pkg/secureagent/daemon.go | 25 +------------ sztp-agent/pkg/secureagent/daemon_test.go | 6 +++- sztp-agent/pkg/secureagent/progress_test.go | 1 + sztp-agent/pkg/secureagent/tls.go | 18 +--------- 11 files changed, 59 insertions(+), 50 deletions(-) diff --git a/sztp-agent/cmd/daemon.go b/sztp-agent/cmd/daemon.go index d309d9a6..276a22f9 100644 --- a/sztp-agent/cmd/daemon.go +++ b/sztp-agent/cmd/daemon.go @@ -59,7 +59,8 @@ func Daemon() *cobra.Command { return fmt.Errorf("must not be folder: %q", filePath) } } - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandDaemon() }, } diff --git a/sztp-agent/cmd/disable.go b/sztp-agent/cmd/disable.go index 3ce70a4d..fd4e285f 100644 --- a/sztp-agent/cmd/disable.go +++ b/sztp-agent/cmd/disable.go @@ -34,7 +34,8 @@ func Disable() *cobra.Command { Use: "disable", Short: "Run the disable command", RunE: func(_ *cobra.Command, _ []string) error { - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandDisable() }, } diff --git a/sztp-agent/cmd/enable.go b/sztp-agent/cmd/enable.go index 745bd795..7bf1299c 100644 --- a/sztp-agent/cmd/enable.go +++ b/sztp-agent/cmd/enable.go @@ -34,7 +34,8 @@ func Enable() *cobra.Command { Use: "enable", Short: "Run the enable command", RunE: func(_ *cobra.Command, _ []string) error { - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandEnable() }, } diff --git a/sztp-agent/cmd/run.go b/sztp-agent/cmd/run.go index f3b02c1f..df5407f7 100644 --- a/sztp-agent/cmd/run.go +++ b/sztp-agent/cmd/run.go @@ -59,7 +59,8 @@ func Run() *cobra.Command { return fmt.Errorf("must not be folder: %q", filePath) } } - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommand() }, } diff --git a/sztp-agent/cmd/status.go b/sztp-agent/cmd/status.go index cf5043a7..bff8842f 100644 --- a/sztp-agent/cmd/status.go +++ b/sztp-agent/cmd/status.go @@ -34,7 +34,8 @@ func Status() *cobra.Command { Use: "status", Short: "Run the status command", RunE: func(_ *cobra.Command, _ []string) error { - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandStatus() }, } diff --git a/sztp-agent/pkg/secureagent/agent.go b/sztp-agent/pkg/secureagent/agent.go index e248dc42..5230c760 100644 --- a/sztp-agent/pkg/secureagent/agent.go +++ b/sztp-agent/pkg/secureagent/agent.go @@ -9,6 +9,13 @@ Copyright (C) 2022 Red Hat. // Package secureagent implements the secure agent package secureagent +import ( + "crypto/tls" + "crypto/x509" + "net/http" + "os" +) + const ( CONTENT_TYPE_YANG = "application/yang-data+json" OS_RELEASE_FILE = "/etc/os-release" @@ -68,6 +75,11 @@ type BootstrapServerErrorOutput struct { } `json:"ietf-restconf:errors"` } +type HttpClient interface { + Get(uri string) (*http.Response, error) + Do(req *http.Request) (*http.Response, error) +} + // Agent is the basic structure to define an agent instance type Agent struct { InputBootstrapURL string // Bootstrap complete URL given by USER @@ -83,10 +95,10 @@ type Agent struct { ProgressJSON ProgressJSON // ProgressJson structure BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure - + HttpClient HttpClient } -func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string) *Agent { +func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string, httpClient HttpClient) *Agent { return &Agent{ InputBootstrapURL: bootstrapURL, BootstrapURL: "", @@ -101,6 +113,7 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP ProgressJSON: ProgressJSON{}, BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, + HttpClient: httpClient, } } @@ -171,3 +184,25 @@ func (a *Agent) SetContentTypeReq(ct string) { func (a *Agent) SetProgressJSON(p ProgressJSON) { a.ProgressJSON = p } + +func NewHttpClient(bootstrapTrustAnchorCert string, deviceEndEntityCert string, devicePrivateKey string) http.Client { + caCert, _ := os.ReadFile(bootstrapTrustAnchorCert) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + cert, _ := tls.LoadX509KeyPair(deviceEndEntityCert, devicePrivateKey) + client := http.Client{ + CheckRedirect: func(r *http.Request, _ []*http.Request) error { + r.URL.Opaque = r.URL.Path + return nil + }, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + //nolint:gosec + InsecureSkipVerify: true, // TODO: remove skip verify + RootCAs: caCertPool, + Certificates: []tls.Certificate{cert}, + }, + }, + } + return client +} diff --git a/sztp-agent/pkg/secureagent/agent_test.go b/sztp-agent/pkg/secureagent/agent_test.go index ad2ebb09..e8234a8a 100644 --- a/sztp-agent/pkg/secureagent/agent_test.go +++ b/sztp-agent/pkg/secureagent/agent_test.go @@ -9,6 +9,7 @@ Copyright (C) 2022 Red Hat. package secureagent import ( + "net/http" "reflect" "testing" ) @@ -829,6 +830,7 @@ func TestNewAgent(t *testing.T) { deviceEndEntityCert string bootstrapTrustAnchorCert string } + client := http.Client{} tests := []struct { name string args args @@ -856,12 +858,13 @@ func TestNewAgent(t *testing.T) { ContentTypeReq: "application/yang-data+json", InputJSONContent: generateInputJSONContent(), DhcpLeaseFile: "TestDhcpLeaseFile", + HttpClient: &client, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert); !reflect.DeepEqual(got, tt.want) { + if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, &client); !reflect.DeepEqual(got, tt.want) { t.Errorf("NewAgent() = %v, want %v", got, tt.want) } }) diff --git a/sztp-agent/pkg/secureagent/daemon.go b/sztp-agent/pkg/secureagent/daemon.go index ef775f95..a7778ea7 100644 --- a/sztp-agent/pkg/secureagent/daemon.go +++ b/sztp-agent/pkg/secureagent/daemon.go @@ -10,8 +10,6 @@ package secureagent import ( "bytes" - "crypto/tls" - "crypto/x509" "encoding/asn1" "encoding/base64" "encoding/json" @@ -19,7 +17,6 @@ import ( "fmt" "io" "log" - "net/http" "net/url" "os" "os/exec" @@ -196,27 +193,7 @@ func (a *Agent) downloadAndValidateImage() error { return err } - caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert()) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey()) - - check := http.Client{ - CheckRedirect: func(r *http.Request, _ []*http.Request) error { - r.URL.Opaque = r.URL.Path - return nil - }, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - //nolint:gosec - InsecureSkipVerify: true, // TODO: remove skip verify - RootCAs: caCertPool, - Certificates: []tls.Certificate{cert}, - }, - }, - } - - response, err := check.Get(item) + response, err := a.HttpClient.Get(item) if err != nil { return err } diff --git a/sztp-agent/pkg/secureagent/daemon_test.go b/sztp-agent/pkg/secureagent/daemon_test.go index eb7d968b..ea784aa3 100644 --- a/sztp-agent/pkg/secureagent/daemon_test.go +++ b/sztp-agent/pkg/secureagent/daemon_test.go @@ -341,6 +341,7 @@ func TestAgent_doReqBootstrap(t *testing.T) { ContentTypeReq: tt.fields.ContentTypeReq, InputJSONContent: tt.fields.InputJSONContent, DhcpLeaseFile: tt.fields.DhcpLeaseFile, + HttpClient: &http.Client{}, } if err := a.doRequestBootstrapServerOnboardingInfo(); (err != nil) != tt.wantErr { t.Errorf("doRequestBootstrapServer() error = %v, wantErr %v", err, tt.wantErr) @@ -359,7 +360,6 @@ func TestAgent_downloadAndValidateImage(t *testing.T) { } })) defer svr.Close() - type fields struct { BootstrapURL string SerialNumber string @@ -638,6 +638,7 @@ func TestAgent_downloadAndValidateImage(t *testing.T) { }, } for _, tt := range tests { + deleteTempTestFile(ARTIFACTS_PATH + "/imageOK") t.Run(tt.name, func(t *testing.T) { a := &Agent{ BootstrapURL: tt.fields.BootstrapURL, @@ -652,6 +653,7 @@ func TestAgent_downloadAndValidateImage(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + HttpClient: svr.Client(), } if err := a.downloadAndValidateImage(); (err != nil) != tt.wantErr { t.Errorf("downloadAndValidateImage() error = %v, wantErr %v", err, tt.wantErr) @@ -807,6 +809,7 @@ func TestAgent_copyConfigurationFile(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + HttpClient: &http.Client{}, } if err := a.copyConfigurationFile(); (err != nil) != tt.wantErr { t.Errorf("copyConfigurationFile() error = %v, wantErr %v", err, tt.wantErr) @@ -1024,6 +1027,7 @@ func TestAgent_launchScriptsConfiguration(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + HttpClient: &http.Client{}, } if err := a.launchScriptsConfiguration(tt.args.typeOf); (err != nil) != tt.wantErr { t.Errorf("launchScriptsConfiguration() error = %v, wantErr %v", err, tt.wantErr) diff --git a/sztp-agent/pkg/secureagent/progress_test.go b/sztp-agent/pkg/secureagent/progress_test.go index 7a55c0cb..a821dcf9 100644 --- a/sztp-agent/pkg/secureagent/progress_test.go +++ b/sztp-agent/pkg/secureagent/progress_test.go @@ -158,6 +158,7 @@ func TestAgent_doReportProgress(t *testing.T) { InputJSONContent: tt.fields.InputJSONContent, DhcpLeaseFile: tt.fields.DhcpLeaseFile, ProgressJSON: tt.fields.ProgressJSON, + HttpClient: &http.Client{}, } if err := a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated"); (err != nil) != tt.wantErr { t.Errorf("doReportProgress() error = %v, wantErr %v", err, tt.wantErr) diff --git a/sztp-agent/pkg/secureagent/tls.go b/sztp-agent/pkg/secureagent/tls.go index 44bdaf68..b33d7636 100644 --- a/sztp-agent/pkg/secureagent/tls.go +++ b/sztp-agent/pkg/secureagent/tls.go @@ -10,14 +10,11 @@ package secureagent import ( "bytes" - "crypto/tls" - "crypto/x509" "encoding/json" "errors" "io" "log" "net/http" - "os" "strconv" "strings" ) @@ -38,20 +35,7 @@ func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapSe r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword()) r.Header.Add("Content-Type", a.GetContentTypeReq()) - caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert()) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey()) - - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ //nolint:gosec - RootCAs: caCertPool, - Certificates: []tls.Certificate{cert}, - }, - }, - } - res, err := client.Do(r) + res, err := a.HttpClient.Do(r) if err != nil { log.Println("Error doing the request", err.Error()) return nil, err