From 8f455c5d469e94c1a4e9efa1f51c5bbfe04bc6b0 Mon Sep 17 00:00:00 2001 From: Boris Glimcher Date: Fri, 7 Jun 2024 06:06:10 +0300 Subject: [PATCH] fix(sztp): send ssh key when onboarding completed Signed-off-by: Boris Glimcher --- sztp-agent/pkg/secureagent/agent_test.go | 2 +- sztp-agent/pkg/secureagent/daemon.go | 43 +++++++++++++++++------ sztp-agent/pkg/secureagent/daemon_test.go | 2 +- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/sztp-agent/pkg/secureagent/agent_test.go b/sztp-agent/pkg/secureagent/agent_test.go index db5e2b94..e98832a7 100644 --- a/sztp-agent/pkg/secureagent/agent_test.go +++ b/sztp-agent/pkg/secureagent/agent_test.go @@ -1019,7 +1019,7 @@ func TestAgent_SetProgressJson(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, } a.SetProgressJSON(tt.args.p) - if ! reflect.DeepEqual(a.GetProgressJSON(), tt.args.p) { + if !reflect.DeepEqual(a.GetProgressJSON(), tt.args.p) { t.Errorf("SetProgressJson = %v, want %v", a.GetProgressJSON(), tt.args.p) } }) diff --git a/sztp-agent/pkg/secureagent/daemon.go b/sztp-agent/pkg/secureagent/daemon.go index bff9aa4b..fbd11ab2 100644 --- a/sztp-agent/pkg/secureagent/daemon.go +++ b/sztp-agent/pkg/secureagent/daemon.go @@ -65,7 +65,7 @@ func (a *Agent) RunCommandDaemon() error { if err != nil { return err } - // _ = a.doReportProgress(ProgressTypeBootstrapComplete) + _ = a.doReportProgress(ProgressTypeBootstrapComplete, true) return nil } @@ -88,10 +88,10 @@ func (a *Agent) getBootstrapURL() error { return nil } -func (a *Agent) doReportProgress(s ProgressType) error { +func (a *Agent) doReportProgress(s ProgressType, needssh bool) error { log.Println("[INFO] Starting the Report Progress request.") url := strings.ReplaceAll(a.GetBootstrapURL(), "get-bootstrapping-data", "report-progress") - a.SetProgressJSON(ProgressJSON{ + p := ProgressJSON{ IetfSztpBootstrapServerInput: struct { ProgressType string `json:"progress-type"` Message string `json:"message"` @@ -105,7 +105,28 @@ func (a *Agent) doReportProgress(s ProgressType) error { ProgressType: s.String(), Message: "message sent via JSON", }, - }) + } + if needssh { + // TODO: generate real key here + encodedKey := base64.StdEncoding.EncodeToString([]byte("mysshpass")) + p.IetfSztpBootstrapServerInput.SSHHostKeys = struct { + SSHHostKey []struct { + Algorithm string `json:"algorithm"` + KeyData string `json:"key-data"` + } `json:"ssh-host-key,omitempty"` + }{ + SSHHostKey: []struct { + Algorithm string `json:"algorithm"` + KeyData string `json:"key-data"` + }{ + { + Algorithm: "ssh-rsa", + KeyData: encodedKey, + }, + }, + } + } + a.SetProgressJSON(p) inputJSON, _ := json.Marshal(a.GetProgressJSON()) res, err := a.doTLSRequest(string(inputJSON), url, true) if err != nil { @@ -150,7 +171,7 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error { return err } log.Println("[INFO] Response retrieved successfully") - _ = a.doReportProgress(ProgressTypeBootstrapInitiated) + _ = a.doReportProgress(ProgressTypeBootstrapInitiated, false) crypto := res.IetfSztpBootstrapServerOutput.ConveyedInformation newVal, err := base64.StdEncoding.DecodeString(crypto) if err != nil { @@ -190,7 +211,7 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error { //nolint:funlen func (a *Agent) downloadAndValidateImage() error { log.Printf("[INFO] Starting the Download Image: %v", a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI) - _ = a.doReportProgress(ProgressTypeBootImageInitiated) + _ = a.doReportProgress(ProgressTypeBootImageInitiated, false) // Download the image from DownloadURI and save it to a file a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference = fmt.Sprintf("%8d", time.Now().Unix()) for i, item := range a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI { @@ -251,7 +272,7 @@ func (a *Agent) downloadAndValidateImage() error { return errors.New("Checksum mismatch") } log.Println("[INFO] Checksum verified successfully") - _ = a.doReportProgress(ProgressTypeBootImageComplete) + _ = a.doReportProgress(ProgressTypeBootImageComplete, false) return nil default: return errors.New("Unsupported hash algorithm") @@ -262,7 +283,7 @@ func (a *Agent) downloadAndValidateImage() error { func (a *Agent) copyConfigurationFile() error { log.Println("[INFO] Starting the Copy Configuration.") - _ = a.doReportProgress(ProgressTypeConfigInitiated) + _ = a.doReportProgress(ProgressTypeConfigInitiated, false) // Copy the configuration file to the device file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + "-config") if err != nil { @@ -283,7 +304,7 @@ func (a *Agent) copyConfigurationFile() error { return err } log.Println("[INFO] Configuration file copied successfully") - _ = a.doReportProgress(ProgressTypeConfigComplete) + _ = a.doReportProgress(ProgressTypeConfigComplete, false) return nil } @@ -303,7 +324,7 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { reportEnd = ProgressTypePreScriptComplete } log.Println("[INFO] Starting the " + scriptName + "-configuration.") - _ = a.doReportProgress(reportStart) + _ = a.doReportProgress(reportStart, false) file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + scriptName + "configuration.sh") if err != nil { log.Println("[ERROR] creating the "+scriptName+"-configuration script", err.Error()) @@ -330,7 +351,7 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { return err } log.Println(string(out)) // remove it - _ = a.doReportProgress(reportEnd) + _ = a.doReportProgress(reportEnd, false) log.Println("[INFO] " + scriptName + "-Configuration script executed successfully") return nil } diff --git a/sztp-agent/pkg/secureagent/daemon_test.go b/sztp-agent/pkg/secureagent/daemon_test.go index 75810404..18c318d9 100644 --- a/sztp-agent/pkg/secureagent/daemon_test.go +++ b/sztp-agent/pkg/secureagent/daemon_test.go @@ -412,7 +412,7 @@ func TestAgent_doReportProgress(t *testing.T) { DhcpLeaseFile: tt.fields.DhcpLeaseFile, ProgressJSON: tt.fields.ProgressJSON, } - if err := a.doReportProgress(ProgressTypeBootstrapInitiated); (err != nil) != tt.wantErr { + if err := a.doReportProgress(ProgressTypeBootstrapInitiated, false); (err != nil) != tt.wantErr { t.Errorf("doReportProgress() error = %v, wantErr %v", err, tt.wantErr) } })