diff --git a/sztp-agent/pkg/secureagent/daemon.go b/sztp-agent/pkg/secureagent/daemon.go index e646592..ef775f9 100644 --- a/sztp-agent/pkg/secureagent/daemon.go +++ b/sztp-agent/pkg/secureagent/daemon.go @@ -10,7 +10,6 @@ package secureagent import ( "bytes" - "crypto/sha256" "crypto/tls" "crypto/x509" "encoding/asn1" @@ -246,28 +245,17 @@ func (a *Agent) downloadAndValidateImage() error { log.Printf("[INFO] Downloaded file: %s with size: %d", ARTIFACTS_PATH+a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference+filepath.Base(item), size) log.Println("[INFO] Verify the file checksum: ", ARTIFACTS_PATH+a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference+filepath.Base(item)) - // TODO: maybe need to move sha calculatinos to a function in util.go switch a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.ImageVerification[i].HashAlgorithm { case "ietf-sztp-conveyed-info:sha-256": - f, err := os.Open(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + filepath.Base(item)) + filePath := ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + filepath.Base(item) + checksum, err := calculateSHA256File(filePath) + original := strings.ReplaceAll(a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.ImageVerification[i].HashValue, ":", "") if err != nil { - log.Panic(err) - return err - } - defer func() { - if err := f.Close(); err != nil { - log.Println("[ERROR] Error when closing:", err) - } - }() - h := sha256.New() - if _, err := io.Copy(h, f); err != nil { - return err + log.Println("[ERROR] Could not calculate checksum", err) } - sum := fmt.Sprintf("%x", h.Sum(nil)) - original := strings.ReplaceAll(a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.ImageVerification[i].HashValue, ":", "") - log.Println("calculated: " + sum) + log.Println("calculated: " + checksum) log.Println("expected : " + original) - if sum != original { + if checksum != original { return errors.New("checksum mismatch") } log.Println("[INFO] Checksum verified successfully") diff --git a/sztp-agent/pkg/secureagent/utils.go b/sztp-agent/pkg/secureagent/utils.go index b8cf22c..36a7630 100644 --- a/sztp-agent/pkg/secureagent/utils.go +++ b/sztp-agent/pkg/secureagent/utils.go @@ -9,8 +9,12 @@ Copyright (C) 2022 Red Hat. package secureagent import ( + "crypto/sha256" "encoding/json" + "fmt" + "io" "log" + "os" "strings" "github.com/go-ini/ini" @@ -63,3 +67,22 @@ func generateInputJSONContent() string { func replaceQuotes(input string) string { return strings.ReplaceAll(input, "\"", "") } + +func calculateSHA256File(filePath string) (string, error) { + f, err := os.Open(filePath) + if err != nil { + log.Panic(err) + return "", err + } + defer func() { + if err := f.Close(); err != nil { + log.Println("[ERROR] Error when closing:", err) + } + }() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + checkSum := fmt.Sprintf("%x", h.Sum(nil)) + return checkSum, nil +} diff --git a/sztp-agent/pkg/secureagent/utils_test.go b/sztp-agent/pkg/secureagent/utils_test.go index a3cd990..bf32fec 100644 --- a/sztp-agent/pkg/secureagent/utils_test.go +++ b/sztp-agent/pkg/secureagent/utils_test.go @@ -5,6 +5,7 @@ package secureagent import ( + "os" "testing" ) @@ -47,3 +48,31 @@ func Test_replaceQuotes(t *testing.T) { }) } } + +func Test_calculateSHA256File(t *testing.T) { + content := []byte("temporary file's content") + file, err := os.CreateTemp("", "example") + if err != nil { + t.Fatal("Failed to create file", err) + } + defer func() { + _ = os.Remove(file.Name()) + }() + + if _, err := file.Write(content); err != nil { + t.Fatal("Failed to write to file", err) + } + + if err := file.Close(); err != nil { + t.Fatal("Unable to close the file", err) + } + + checksum, err := calculateSHA256File(file.Name()) + if err != nil { + t.Fatal("Could not calculate SHA256", file.Name()) + } + expected := "df3ae2e9b295f790e12e6cf440ffc461d4660f266b84865f14c5508cf68e6f3d" + if checksum != expected { + t.Errorf("Checksum did not match %s %s", checksum, expected) + } +}