From 065d27d510ac28474a2b0792df6aac64a108eb97 Mon Sep 17 00:00:00 2001 From: Suraj Shirvankar Date: Mon, 5 Aug 2024 19:26:26 +0200 Subject: [PATCH] feat: skip image download if it exists Signed-off-by: Suraj Shirvankar --- sztp-agent/cmd/daemon.go | 3 +- sztp-agent/cmd/disable.go | 3 +- sztp-agent/cmd/enable.go | 3 +- sztp-agent/cmd/run.go | 3 +- sztp-agent/cmd/status.go | 3 +- sztp-agent/pkg/secureagent/agent.go | 39 ++++- sztp-agent/pkg/secureagent/agent_test.go | 5 +- sztp-agent/pkg/secureagent/daemon.go | 155 +++++++++++--------- sztp-agent/pkg/secureagent/daemon_test.go | 106 ++++++++++++- sztp-agent/pkg/secureagent/progress_test.go | 1 + sztp-agent/pkg/secureagent/tls.go | 18 +-- 11 files changed, 243 insertions(+), 96 deletions(-) diff --git a/sztp-agent/cmd/daemon.go b/sztp-agent/cmd/daemon.go index d309d9a6..276a22f9 100644 --- a/sztp-agent/cmd/daemon.go +++ b/sztp-agent/cmd/daemon.go @@ -59,7 +59,8 @@ func Daemon() *cobra.Command { return fmt.Errorf("must not be folder: %q", filePath) } } - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandDaemon() }, } diff --git a/sztp-agent/cmd/disable.go b/sztp-agent/cmd/disable.go index 3ce70a4d..fd4e285f 100644 --- a/sztp-agent/cmd/disable.go +++ b/sztp-agent/cmd/disable.go @@ -34,7 +34,8 @@ func Disable() *cobra.Command { Use: "disable", Short: "Run the disable command", RunE: func(_ *cobra.Command, _ []string) error { - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandDisable() }, } diff --git a/sztp-agent/cmd/enable.go b/sztp-agent/cmd/enable.go index 745bd795..7bf1299c 100644 --- a/sztp-agent/cmd/enable.go +++ b/sztp-agent/cmd/enable.go @@ -34,7 +34,8 @@ func Enable() *cobra.Command { Use: "enable", Short: "Run the enable command", RunE: func(_ *cobra.Command, _ []string) error { - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandEnable() }, } diff --git a/sztp-agent/cmd/run.go b/sztp-agent/cmd/run.go index f3b02c1f..df5407f7 100644 --- a/sztp-agent/cmd/run.go +++ b/sztp-agent/cmd/run.go @@ -59,7 +59,8 @@ func Run() *cobra.Command { return fmt.Errorf("must not be folder: %q", filePath) } } - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommand() }, } diff --git a/sztp-agent/cmd/status.go b/sztp-agent/cmd/status.go index cf5043a7..bff8842f 100644 --- a/sztp-agent/cmd/status.go +++ b/sztp-agent/cmd/status.go @@ -34,7 +34,8 @@ func Status() *cobra.Command { Use: "status", Short: "Run the status command", RunE: func(_ *cobra.Command, _ []string) error { - a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert) + client := secureagent.NewHttpClient(bootstrapTrustAnchorCert, deviceEndEntityCert, devicePrivateKey) + a := secureagent.NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, &client) return a.RunCommandStatus() }, } diff --git a/sztp-agent/pkg/secureagent/agent.go b/sztp-agent/pkg/secureagent/agent.go index e248dc42..ed71dfdd 100644 --- a/sztp-agent/pkg/secureagent/agent.go +++ b/sztp-agent/pkg/secureagent/agent.go @@ -9,6 +9,13 @@ Copyright (C) 2022 Red Hat. // Package secureagent implements the secure agent package secureagent +import ( + "crypto/tls" + "crypto/x509" + "net/http" + "os" +) + const ( CONTENT_TYPE_YANG = "application/yang-data+json" OS_RELEASE_FILE = "/etc/os-release" @@ -68,6 +75,11 @@ type BootstrapServerErrorOutput struct { } `json:"ietf-restconf:errors"` } +type HttpGetter interface { + Get(uri string) (*http.Response, error) + Do(req *http.Request) (*http.Response, error) +} + // Agent is the basic structure to define an agent instance type Agent struct { InputBootstrapURL string // Bootstrap complete URL given by USER @@ -83,10 +95,10 @@ type Agent struct { ProgressJSON ProgressJSON // ProgressJson structure BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure - + HttpClient HttpGetter } -func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string) *Agent { +func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert string, httpClient HttpGetter) *Agent { return &Agent{ InputBootstrapURL: bootstrapURL, BootstrapURL: "", @@ -101,6 +113,7 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP ProgressJSON: ProgressJSON{}, BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, + HttpClient: httpClient, } } @@ -171,3 +184,25 @@ func (a *Agent) SetContentTypeReq(ct string) { func (a *Agent) SetProgressJSON(p ProgressJSON) { a.ProgressJSON = p } + +func NewHttpClient(bootstrapTrustAnchorCert string, deviceEndEntityCert string, devicePrivateKey string) http.Client { + caCert, _ := os.ReadFile(bootstrapTrustAnchorCert) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + cert, _ := tls.LoadX509KeyPair(deviceEndEntityCert, devicePrivateKey) + client := http.Client{ + CheckRedirect: func(r *http.Request, _ []*http.Request) error { + r.URL.Opaque = r.URL.Path + return nil + }, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + //nolint:gosec + InsecureSkipVerify: true, // TODO: remove skip verify + RootCAs: caCertPool, + Certificates: []tls.Certificate{cert}, + }, + }, + } + return client +} diff --git a/sztp-agent/pkg/secureagent/agent_test.go b/sztp-agent/pkg/secureagent/agent_test.go index ad2ebb09..e8234a8a 100644 --- a/sztp-agent/pkg/secureagent/agent_test.go +++ b/sztp-agent/pkg/secureagent/agent_test.go @@ -9,6 +9,7 @@ Copyright (C) 2022 Red Hat. package secureagent import ( + "net/http" "reflect" "testing" ) @@ -829,6 +830,7 @@ func TestNewAgent(t *testing.T) { deviceEndEntityCert string bootstrapTrustAnchorCert string } + client := http.Client{} tests := []struct { name string args args @@ -856,12 +858,13 @@ func TestNewAgent(t *testing.T) { ContentTypeReq: "application/yang-data+json", InputJSONContent: generateInputJSONContent(), DhcpLeaseFile: "TestDhcpLeaseFile", + HttpClient: &client, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert); !reflect.DeepEqual(got, tt.want) { + if got := NewAgent(tt.args.bootstrapURL, tt.args.serialNumber, tt.args.dhcpLeaseFile, tt.args.devicePassword, tt.args.devicePrivateKey, tt.args.deviceEndEntityCert, tt.args.bootstrapTrustAnchorCert, &client); !reflect.DeepEqual(got, tt.want) { t.Errorf("NewAgent() = %v, want %v", got, tt.want) } }) diff --git a/sztp-agent/pkg/secureagent/daemon.go b/sztp-agent/pkg/secureagent/daemon.go index ef775f95..660d1a4f 100644 --- a/sztp-agent/pkg/secureagent/daemon.go +++ b/sztp-agent/pkg/secureagent/daemon.go @@ -10,8 +10,6 @@ package secureagent import ( "bytes" - "crypto/tls" - "crypto/x509" "encoding/asn1" "encoding/base64" "encoding/json" @@ -19,7 +17,6 @@ import ( "fmt" "io" "log" - "net/http" "net/url" "os" "os/exec" @@ -181,88 +178,106 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error { return errri } -//nolint:funlen -func (a *Agent) downloadAndValidateImage() error { - log.Printf("[INFO] Starting the Download Image: %v", a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI) - _ = a.doReportProgress(ProgressTypeBootImageInitiated, "BootImage Initiated") - // Download the image from DownloadURI and save it to a file - a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference = fmt.Sprintf("%8d", time.Now().Unix()) - for i, item := range a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI { - // TODO: maybe need to file download to a function in util.go - log.Printf("[INFO] Downloading Image %v", item) - // Create a empty file - file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + filepath.Base(item)) - if err != nil { - return err - } +func (a *Agent) downloadArtifact(uri string) (*os.File, error) { + file, err := os.CreateTemp("", filepath.Base(uri)) + if err != nil { + return nil, err + } - caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert()) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey()) + response, err := a.HttpClient.Get(uri) + if err != nil { + return nil, err + } + sizeorigin, _ := strconv.Atoi(response.Header.Get("Content-Length")) + downloadSize := int64(sizeorigin) + log.Printf("[INFO] Downloading the image with size: %v", downloadSize) - check := http.Client{ - CheckRedirect: func(r *http.Request, _ []*http.Request) error { - r.URL.Opaque = r.URL.Path - return nil - }, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - //nolint:gosec - InsecureSkipVerify: true, // TODO: remove skip verify - RootCAs: caCertPool, - Certificates: []tls.Certificate{cert}, - }, - }, + if response.StatusCode != 200 { + return nil, errors.New("received non 200 response code") + } + size, err := io.Copy(file, response.Body) + if err != nil { + return nil, err + } + defer func() { + if err := file.Close(); err != nil { + log.Println("[ERROR] Error when closing:", err) } + }() + defer func() { + if err := response.Body.Close(); err != nil { + log.Println("[ERROR] Error when closing:", err) + } + }() + log.Printf("[INFO] Downloaded file: %s with size: %d", file.Name(), size) + return file, nil +} + +func (a *Agent) validateImage(filePath string, algorithm string, expected string) error { + switch algorithm { + case "ietf-sztp-conveyed-info:sha-256": + checksum, err := calculateSHA256File(filePath) - response, err := check.Get(item) if err != nil { - return err + log.Println("[ERROR] Could not calculate checksum", err) } + log.Println("calculated: " + checksum) + log.Println("expected : " + expected) - sizeorigin, _ := strconv.Atoi(response.Header.Get("Content-Length")) - downloadSize := int64(sizeorigin) - log.Printf("[INFO] Downloading the image with size: %v", downloadSize) - - if response.StatusCode != 200 { - return errors.New("received non 200 response code") + if checksum != expected { + return errors.New("checksum mismatch") } - size, err := io.Copy(file, response.Body) - if err != nil { - return err + log.Println("[INFO] Checksum verified successfully") + return nil + default: + return errors.New("unsupported hash algorithm") + } +} + +func (a *Agent) artifactExists(item string, algorithm string, expected string) bool { + filePath := ARTIFACTS_PATH + filepath.Base(item) + _, err := os.Stat(filePath) + if err != nil { + return false + } + err = a.validateImage(filePath, algorithm, expected) + return err == nil +} + +//nolint:funlen +func (a *Agent) downloadAndValidateImage() error { + bootImage := a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage + log.Printf("[INFO] Starting the Download Image: %v", bootImage.DownloadURI) + _ = a.doReportProgress(ProgressTypeBootImageInitiated, "BootImage Initiated") + // Download the image from DownloadURI and save it to a file + for i, item := range bootImage.DownloadURI { + if len(bootImage.ImageVerification) <= i { + return errors.New("invalid verification") } - defer func() { - if err := file.Close(); err != nil { - log.Println("[ERROR] Error when closing:", err) - } - }() - defer func() { - if err := response.Body.Close(); err != nil { - log.Println("[ERROR] Error when closing:", err) - } - }() - log.Printf("[INFO] Downloaded file: %s with size: %d", ARTIFACTS_PATH+a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference+filepath.Base(item), size) - log.Println("[INFO] Verify the file checksum: ", ARTIFACTS_PATH+a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference+filepath.Base(item)) - switch a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.ImageVerification[i].HashAlgorithm { - case "ietf-sztp-conveyed-info:sha-256": - filePath := ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + filepath.Base(item) - checksum, err := calculateSHA256File(filePath) - original := strings.ReplaceAll(a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.ImageVerification[i].HashValue, ":", "") + imageVerification := bootImage.ImageVerification[i] + expected := strings.ReplaceAll(imageVerification.HashValue, ":", "") + algorithm := imageVerification.HashAlgorithm + + if a.artifactExists(item, algorithm, expected) { + log.Printf("[INFO] Image %v already exists", item) + } else { + log.Printf("[INFO] Downloading Image %v", item) + file, err := a.downloadArtifact(item) if err != nil { - log.Println("[ERROR] Could not calculate checksum", err) + return err } - log.Println("calculated: " + checksum) - log.Println("expected : " + original) - if checksum != original { - return errors.New("checksum mismatch") + log.Println("[INFO] Verify the file checksum: ", file.Name()) + err = a.validateImage(file.Name(), algorithm, expected) + if err != nil { + return err } - log.Println("[INFO] Checksum verified successfully") + + log.Printf("[INFO] Moving file %s to %s", file.Name(), ARTIFACTS_PATH+filepath.Base(item)) + _ = os.Rename(file.Name(), ARTIFACTS_PATH+filepath.Base(item)) _ = a.doReportProgress(ProgressTypeBootImageComplete, "BootImage Complete") + return nil - default: - return errors.New("unsupported hash algorithm") } } return nil diff --git a/sztp-agent/pkg/secureagent/daemon_test.go b/sztp-agent/pkg/secureagent/daemon_test.go index eb7d968b..a1217b26 100644 --- a/sztp-agent/pkg/secureagent/daemon_test.go +++ b/sztp-agent/pkg/secureagent/daemon_test.go @@ -7,10 +7,12 @@ package secureagent import ( "encoding/json" "fmt" + "io" "log" "net/http" "net/http/httptest" "os" + "strings" "testing" ) @@ -341,6 +343,7 @@ func TestAgent_doReqBootstrap(t *testing.T) { ContentTypeReq: tt.fields.ContentTypeReq, InputJSONContent: tt.fields.InputJSONContent, DhcpLeaseFile: tt.fields.DhcpLeaseFile, + HttpClient: &http.Client{}, } if err := a.doRequestBootstrapServerOnboardingInfo(); (err != nil) != tt.wantErr { t.Errorf("doRequestBootstrapServer() error = %v, wantErr %v", err, tt.wantErr) @@ -349,6 +352,19 @@ func TestAgent_doReqBootstrap(t *testing.T) { } } +type MockClient struct { + GetFunc func(uri string) (*http.Response, error) + DoFunc func(req *http.Request) (*http.Response, error) +} + +func (m *MockClient) Do(req *http.Request) (*http.Response, error) { + return m.DoFunc(req) +} + +func (m *MockClient) Get(uri string) (*http.Response, error) { + return m.GetFunc(uri) +} + //nolint:funlen func TestAgent_downloadAndValidateImage(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -359,7 +375,6 @@ func TestAgent_downloadAndValidateImage(t *testing.T) { } })) defer svr.Close() - type fields struct { BootstrapURL string SerialNumber string @@ -638,6 +653,7 @@ func TestAgent_downloadAndValidateImage(t *testing.T) { }, } for _, tt := range tests { + deleteTempTestFile(ARTIFACTS_PATH + "/imageOK") t.Run(tt.name, func(t *testing.T) { a := &Agent{ BootstrapURL: tt.fields.BootstrapURL, @@ -652,12 +668,98 @@ func TestAgent_downloadAndValidateImage(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + HttpClient: svr.Client(), } if err := a.downloadAndValidateImage(); (err != nil) != tt.wantErr { t.Errorf("downloadAndValidateImage() error = %v, wantErr %v", err, tt.wantErr) } }) } + calls := 0 + httpClient := MockClient{ + GetFunc: func(_ string) (*http.Response, error) { + calls++ + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }, + DoFunc: func(_ *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(""))}, nil + }, + } + a := &Agent{ + BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{ + IetfSztpConveyedInfoOnboardingInformation: struct { + InfoTimestampReference string + BootImage struct { + DownloadURI []string `json:"download-uri"` + ImageVerification []struct { + HashAlgorithm string `json:"hash-algorithm"` + HashValue string `json:"hash-value"` + } `json:"image-verification"` + } `json:"boot-image"` + PreConfigurationScript string `json:"pre-configuration-script"` + ConfigurationHandling string `json:"configuration-handling"` + Configuration string `json:"configuration"` + PostConfigurationScript string `json:"post-configuration-script"` + }{ + InfoTimestampReference: "TIMESTAMP", + BootImage: struct { + DownloadURI []string `json:"download-uri"` + ImageVerification []struct { + HashAlgorithm string `json:"hash-algorithm"` + HashValue string `json:"hash-value"` + } `json:"image-verification"` + }{ + DownloadURI: []string{svr.URL + "/imageOK"}, + ImageVerification: []struct { + HashAlgorithm string `json:"hash-algorithm"` + HashValue string `json:"hash-value"` + }{{ + HashAlgorithm: "ietf-sztp-conveyed-info:sha-256", + HashValue: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + }}, + }, + PreConfigurationScript: "", + ConfigurationHandling: "", + Configuration: "", + PostConfigurationScript: "", + }, + }, + BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, + HttpClient: &httpClient, + } + t.Run("OK case with cached file", func(t *testing.T) { + calls = 0 + deleteTempTestFile(ARTIFACTS_PATH + "/imageOK") + // Initiate cache download + err := a.downloadAndValidateImage() + if err != nil { + t.Errorf("downloadAndValidateImage() error = %v", err) + } + err = a.downloadAndValidateImage() + if err != nil { + t.Errorf("downloadAndValidateImage() error = %v", err) + } + if calls != 1 { + t.Errorf("downloadAndValidateImage() called httpclient more than 1 times, Called %d", calls) + } + }) + t.Run("OK case with cached file with different signature", func(t *testing.T) { + calls = 0 + deleteTempTestFile(ARTIFACTS_PATH + "/imageOK") + // Initiate cache download + err := a.downloadAndValidateImage() + if err != nil { + t.Errorf("downloadAndValidateImage() error = %v", err) + } + _ = os.WriteFile(ARTIFACTS_PATH+"/imageOK", []byte("test"), 0600) + err = a.downloadAndValidateImage() + if err != nil { + t.Errorf("downloadAndValidateImage() error = %v", err) + } + if calls != 2 { + t.Errorf("downloadAndValidateImage() should call httpclient two time, Called %d", calls) + } + }) } // nolint:funlen @@ -807,6 +909,7 @@ func TestAgent_copyConfigurationFile(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + HttpClient: &http.Client{}, } if err := a.copyConfigurationFile(); (err != nil) != tt.wantErr { t.Errorf("copyConfigurationFile() error = %v, wantErr %v", err, tt.wantErr) @@ -1024,6 +1127,7 @@ func TestAgent_launchScriptsConfiguration(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, + HttpClient: &http.Client{}, } if err := a.launchScriptsConfiguration(tt.args.typeOf); (err != nil) != tt.wantErr { t.Errorf("launchScriptsConfiguration() error = %v, wantErr %v", err, tt.wantErr) diff --git a/sztp-agent/pkg/secureagent/progress_test.go b/sztp-agent/pkg/secureagent/progress_test.go index 7a55c0cb..a821dcf9 100644 --- a/sztp-agent/pkg/secureagent/progress_test.go +++ b/sztp-agent/pkg/secureagent/progress_test.go @@ -158,6 +158,7 @@ func TestAgent_doReportProgress(t *testing.T) { InputJSONContent: tt.fields.InputJSONContent, DhcpLeaseFile: tt.fields.DhcpLeaseFile, ProgressJSON: tt.fields.ProgressJSON, + HttpClient: &http.Client{}, } if err := a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated"); (err != nil) != tt.wantErr { t.Errorf("doReportProgress() error = %v, wantErr %v", err, tt.wantErr) diff --git a/sztp-agent/pkg/secureagent/tls.go b/sztp-agent/pkg/secureagent/tls.go index 44bdaf68..b33d7636 100644 --- a/sztp-agent/pkg/secureagent/tls.go +++ b/sztp-agent/pkg/secureagent/tls.go @@ -10,14 +10,11 @@ package secureagent import ( "bytes" - "crypto/tls" - "crypto/x509" "encoding/json" "errors" "io" "log" "net/http" - "os" "strconv" "strings" ) @@ -38,20 +35,7 @@ func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapSe r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword()) r.Header.Add("Content-Type", a.GetContentTypeReq()) - caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert()) - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey()) - - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ //nolint:gosec - RootCAs: caCertPool, - Certificates: []tls.Certificate{cert}, - }, - }, - } - res, err := client.Do(r) + res, err := a.HttpClient.Do(r) if err != nil { log.Println("Error doing the request", err.Error()) return nil, err