From ec0d453b990a956f5c06856d9c5989c726016a84 Mon Sep 17 00:00:00 2001 From: Sagar Muchhal Date: Wed, 26 Aug 2020 12:28:59 -0700 Subject: [PATCH] Adds support for passphrase protected ssh key This patch adds support for an instance of ssh-agent which is used to communicate (ssh/scp) with remote compute resources. Whenever the Starlark crashd_config() directive uses the `use_ssh_agent` parameter, a new instance of ssh.Agent is started and keys are added to this Agent instance. All ssh and scp operations then use this agent to connect to remote resources. This patch also includes: - some variable name changes for some Starlark builtins. - bumping of the test workflow to use go 1.15 since we rely on new functionality introduced post go 1.14 Signed-off-by: Sagar Muchhal --- .github/workflows/compile-test.yaml | 4 +- exec/executor.go | 9 +- go.mod | 2 +- ssh/agent.go | 175 ++++++++++++++++++++++++++++ ssh/agent_test.go | 160 +++++++++++++++++++++++++ ssh/scp.go | 7 +- ssh/scp_test.go | 2 +- ssh/ssh.go | 15 ++- ssh/ssh_test.go | 4 +- ssh/test_support.go | 10 +- starlark/capa_provider.go | 2 +- starlark/capture.go | 23 ++-- starlark/copy_from.go | 23 ++-- starlark/crashd_config.go | 18 ++- starlark/crashd_config_test.go | 21 ++++ starlark/resources.go | 2 +- starlark/run.go | 22 +++- starlark/ssh_config.go | 32 +++-- starlark/starlark_exec.go | 19 ++- starlark/support.go | 4 + 20 files changed, 498 insertions(+), 56 deletions(-) create mode 100644 ssh/agent.go create mode 100644 ssh/agent_test.go diff --git a/.github/workflows/compile-test.yaml b/.github/workflows/compile-test.yaml index 873f1e3b..a0eaf963 100644 --- a/.github/workflows/compile-test.yaml +++ b/.github/workflows/compile-test.yaml @@ -6,10 +6,10 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.13 + - name: Set up Go 1.15 uses: actions/setup-go@v1 with: - go-version: 1.13.x + go-version: 1.15 id: go - name: Check out code into the Go module directory diff --git a/exec/executor.go
b/exec/executor.go index c62e610c..0b105833 100644 --- a/exec/executor.go +++ b/exec/executor.go @@ -4,10 +4,10 @@ package exec import ( - "fmt" "io" "os" + "github.com/pkg/errors" "github.com/vmware-tanzu/crash-diagnostics/starlark" ) @@ -25,11 +25,12 @@ func Execute(name string, source io.Reader, args ArgMap) error { star.AddPredeclared("args", starStruct) } - if err := star.Exec(name, source); err != nil { - return fmt.Errorf("exec failed: %s", err) + err := star.Exec(name, source) + if err != nil { + err = errors.Wrap(err, "exec failed") } - return nil + return err } func ExecuteFile(file *os.File, args ArgMap) error { diff --git a/go.mod b/go.mod index 5b72ae1a..e0eefd5e 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/vmware-tanzu/crash-diagnostics -go 1.12 +go 1.15 require ( github.com/imdario/mergo v0.3.7 // indirect diff --git a/ssh/agent.go b/ssh/agent.go new file mode 100644 index 00000000..8d6a365d --- /dev/null +++ b/ssh/agent.go @@ -0,0 +1,175 @@ +// Copyright (c) 2020 VMware, Inc. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package ssh + +import ( + "bufio" + "fmt" + "io" + "strings" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/vladimirvivien/echo" +) + +// ssh-agent constant identifiers +const ( + AgentPidIdentifier = "SSH_AGENT_PID" + AuthSockIdentifier = "SSH_AUTH_SOCK" +) + +type Agent interface { + AddKey(keyPath string) error + RemoveKey(keyPath string) error + Stop() error + GetEnvVariables() string +} + +// agentInfo captures the connection information of the ssh-agent +type agentInfo map[string]string + +// agent represents an instance of the ssh-agent +type agent struct { + // Pid of the ssh-agent + Pid string + + // Authentication socket to communicate with the ssh-agent + AuthSockPath string + + // File paths of the keys added to the ssh-agent + KeyPaths []string +} + +// AddKey adds a key to the ssh-agent process +func (agent *agent) AddKey(keyPath string) error { + e := echo.New() + sshAddCmd := e.Prog.Avail("ssh-add") + if len(sshAddCmd) == 0 { + return errors.New("ssh-add not found") + } + + p := e.Env(agent.GetEnvVariables()). + RunProc(fmt.Sprintf("%s %s", sshAddCmd, keyPath)) + if err := p.Err(); err != nil { + return errors.Wrapf(err, "could not add key %s to ssh-agent", keyPath) + } + agent.KeyPaths = append(agent.KeyPaths, keyPath) + return nil +} + +// RemoveKey removes a key from the ssh-agent process +func (agent *agent) RemoveKey(keyPath string) error { + e := echo.New() + sshAddCmd := e.Prog.Avail("ssh-add") + if len(sshAddCmd) == 0 { + return errors.New("ssh-add not found") + } + + p := e.Env(agent.GetEnvVariables()). + RunProc(fmt.Sprintf("%s -d %s", sshAddCmd, keyPath)) + if err := p.Err(); err != nil { + return errors.Wrapf(err, "could not remove key %s from ssh-agent", keyPath) + } + + return nil +} + +// Stop kills the ssh-agent process.
+// It also tries to remove the added keys from the agent +func (agent *agent) Stop() error { + for _, path := range agent.KeyPaths { + logrus.Debugf("removing key from ssh-agent: %s", path) + err := agent.RemoveKey(path) + if err != nil { + logrus.Warnf("failed to remove SSH key from agent: %s", err) + } + } + + logrus.Debugf("stopping the ssh-agent with Pid: %s", agent.Pid) + p := echo.New().Env(agent.GetEnvVariables()).RunProc("ssh-agent -k") + + return p.Err() +} + +// GetEnvVariables returns the space separated key=value information used to communicate with the ssh-agent +func (agent *agent) GetEnvVariables() string { + return fmt.Sprintf("%s=%s %s=%s", AgentPidIdentifier, agent.Pid, AuthSockIdentifier, agent.AuthSockPath) +} + +// StartAgent starts the ssh-agent process and returns the SSH authentication parameters. +func StartAgent() (Agent, error) { + e := echo.New() + sshAgentCmd := e.Prog.Avail("ssh-agent") + if len(sshAgentCmd) == 0 { + return nil, fmt.Errorf("ssh-agent not found") + } + + p := e.RunProc(fmt.Sprintf("%s -s", sshAgentCmd)) + if p.Err() != nil { + return nil, errors.Wrap(p.Err(), "failed to start ssh agent") + } + + agentInfo, err := parseAgentInfo(p.Out()) + if err != nil { + return nil, err + } + if err := validateAgentInfo(agentInfo); err != nil { + return nil, err + } + + return agentFromInfo(agentInfo), nil +} + +// parseAgentInfo parses the output of ssh-agent -s to determine the information +// for the ssh authentication agent. 
+// example output: +// SSH_AUTH_SOCK=/foo/bar.1234; export SSH_AUTH_SOCK; +// SSH_AGENT_PID=4567; export SSH_AGENT_PID; +// echo Agent pid 4567; +func parseAgentInfo(info io.Reader) (agentInfo, error) { + agentInfo := map[string]string{} + + scanner := bufio.NewScanner(info) + for scanner.Scan() { + line := scanner.Text() + // separate the line using the semi-colon as a separator + if equal := strings.Index(line, ";"); equal >= 0 { + s := strings.Split(line, ";")[0] + // check if any key=value pair is present + if equal := strings.Index(s, "="); equal >= 0 { + kv := strings.SplitN(s, "=", 2) + // store the key-value pair in the map; only the first '=' separates key from value + agentInfo[kv[0]] = kv[1] + } + } + } + // a read error is only reported by the scanner once Scan has returned false + if err := scanner.Err(); err != nil { + return nil, err + } + + return agentInfo, nil +} + +// validateAgentInfo checks whether the ssh-agent information is valid +func validateAgentInfo(info agentInfo) error { + if len(info) != 2 { + return errors.New("faulty ssh-agent identifier info") + } + for k, v := range info { + if (k != AgentPidIdentifier && k != AuthSockIdentifier) || len(v) == 0 { + return errors.New("faulty ssh-agent identifier info") + } + } + return nil +} + +// agentFromInfo parses the information map and returns an instance of agent +func agentFromInfo(agentInfo agentInfo) *agent { + return &agent{ + Pid: agentInfo[AgentPidIdentifier], + AuthSockPath: agentInfo[AuthSockIdentifier], + } +} diff --git a/ssh/agent_test.go b/ssh/agent_test.go new file mode 100644 index 00000000..b6ae761c --- /dev/null +++ b/ssh/agent_test.go @@ -0,0 +1,160 @@ +// Copyright (c) 2019 VMware, Inc. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0 + +package ssh + +import ( + "bufio" + "regexp" + "strings" + "testing" + + "github.com/vladimirvivien/echo" +) + +func TestParseAndValidateAgentInfo(t *testing.T) { + tests := []struct { + name string + info string + shouldErr bool + }{ + { + name: "valid info", + shouldErr: false, + info: `SSH_AUTH_SOCK=/foo/bar.1234; export SSH_AUTH_SOCK; +SSH_AGENT_PID=4567; export SSH_AGENT_PID; +echo Agent pid 4567;`, + }, + { + name: "invalid info", + shouldErr: true, + info: `FOO=/foo/bar.1234; export BAR; +BLAH=4567; export BLOOP; +echo lorem ipsum 4567;`, + }, + { + name: "invalid info", + shouldErr: true, + info: `SSH_AUTH_SOCK=/foo/bar.1234; export SSH_AUTH_SOCK; +BLAH=4567; export BLOOP; +echo lorem ipsum 4567;`, + }, + { + name: "invalid info", + shouldErr: true, + info: `FOO=/foo/bar.1234; export BAR; +SSH_AGENT_PID=4567; export SSH_AGENT_PID; +echo lorem ipsum 4567;`, + }, + { + name: "invalid info", + shouldErr: true, + info: `lorem ipsum 1; +lorem ipsum 2.`, + }, + { + name: "invalid info", + shouldErr: true, + info: "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + agentInfo, err := parseAgentInfo(strings.NewReader(test.info)) + if err != nil { + t.Fail() + } + err = validateAgentInfo(agentInfo) + if err != nil && !test.shouldErr { + // unexpected failures + t.Fail() + } else if !test.shouldErr { + if _, ok := agentInfo[AgentPidIdentifier]; !ok { + t.Fail() + } + if _, ok := agentInfo[AuthSockIdentifier]; !ok { + t.Fail() + } + } else { + // asserting error scenarios + if err == nil { + t.Fail() + } + } + }) + } +} + +func TestStartAgent(t *testing.T) { + a, err := StartAgent() + if err != nil || a == nil { + t.Fatalf("error should be nil and agent should not be nil: %v", err) + } + out := echo.New().Run("ps -ax") + if !strings.Contains(out, "ssh-agent") { + t.Fatal("no ssh-agent process found") + } + + failed := true + scanner := bufio.NewScanner(strings.NewReader(out)) + for 
scanner.Scan() { + line := scanner.Text() + if strings.Contains(line, "ssh-agent") { + pid := strings.Split(strings.TrimSpace(line), " ")[0] + // set failed to false if correct ssh-agent process is found + agentStruct, _ := a.(*agent) + if pid == agentStruct.Pid { + failed = false + } + } + } + if failed { + t.Fatal("could not find agent with correct Pid") + } + + t.Cleanup(func() { + _ = a.Stop() + }) +} + +func TestAgent(t *testing.T) { + a, err := StartAgent() + if err != nil { + t.Fatalf("failed to start agent: %v", err) + } + + tests := []struct { + name string + assert func(*testing.T, Agent) + }{ + { + name: "GetEnvVariables", + assert: func(t *testing.T, agent Agent) { + vars := agent.GetEnvVariables() + if len(strings.Split(vars, " ")) != 2 { + t.Fatalf("not enough variables") + } + + match, err := regexp.MatchString(`SSH_AGENT_PID=[0-9]+ SSH_AUTH_SOCK=\S*`, vars) + if err != nil || !match { + t.Fatalf("format does not match") + } + }, + }, + { + name: "Stop", + assert: func(t *testing.T, agent Agent) { + if err := agent.Stop(); err != nil { + t.Errorf("failed to stop agent: %s", err) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.assert(t, a) + }) + } +} diff --git a/ssh/scp.go b/ssh/scp.go index c4781b94..1f0b26bd 100644 --- a/ssh/scp.go +++ b/ssh/scp.go @@ -17,7 +17,7 @@ import ( // CopyFrom copies one or more files using SCP from remote host // and returns the paths of files that were successfully copied. 
-func CopyFrom(args SSHArgs, rootDir string, sourcePath string) error { +func CopyFrom(args SSHArgs, agent Agent, rootDir string, sourcePath string) error { e := echo.New() prog := e.Prog.Avail("scp") if len(prog) == 0 { @@ -44,6 +44,11 @@ func CopyFrom(args SSHArgs, rootDir string, sourcePath string) error { effectiveCmd := fmt.Sprintf(`%s %s`, sshCmd, targetPath) logrus.Debug("scp: ", effectiveCmd) + if agent != nil { + logrus.Debugf("Adding agent info: %s", agent.GetEnvVariables()) + e = e.Env(agent.GetEnvVariables()) + } + maxRetries := args.MaxRetries if maxRetries == 0 { maxRetries = 10 diff --git a/ssh/scp_test.go b/ssh/scp_test.go index 84d790d3..1f969fca 100644 --- a/ssh/scp_test.go +++ b/ssh/scp_test.go @@ -54,7 +54,7 @@ func TestCopy(t *testing.T) { MakeTestSSHFile(t, test.sshArgs, file, content) } - if err := CopyFrom(test.sshArgs, support.TmpDirRoot(), test.srcFile); err != nil { + if err := CopyFrom(test.sshArgs, nil, support.TmpDirRoot(), test.srcFile); err != nil { t.Fatal(err) } diff --git a/ssh/ssh.go b/ssh/ssh.go index c41d9cd5..641f09b0 100644 --- a/ssh/ssh.go +++ b/ssh/ssh.go @@ -30,8 +30,8 @@ type SSHArgs struct { } // Run runs a command over SSH and returns the result as a string -func Run(args SSHArgs, cmd string) (string, error) { - reader, err := sshRunProc(args, cmd) +func Run(args SSHArgs, agent Agent, cmd string) (string, error) { + reader, err := sshRunProc(args, agent, cmd) if err != nil { return "", err } @@ -43,11 +43,11 @@ func Run(args SSHArgs, cmd string) (string, error) { } // RunRead runs a command over SSH and returns an io.Reader for stdout/stderr -func RunRead(args SSHArgs, cmd string) (io.Reader, error) { - return sshRunProc(args, cmd) +func RunRead(args SSHArgs, agent Agent, cmd string) (io.Reader, error) { + return sshRunProc(args, agent, cmd) } -func sshRunProc(args SSHArgs, cmd string) (io.Reader, error) { +func sshRunProc(args SSHArgs, agent Agent, cmd string) (io.Reader, error) { e := echo.New() prog := 
e.Prog.Avail("ssh") if len(prog) == 0 { @@ -61,6 +61,11 @@ func sshRunProc(args SSHArgs, cmd string) (io.Reader, error) { effectiveCmd := fmt.Sprintf(`%s "%s"`, sshCmd, cmd) logrus.Debug("ssh.run: ", effectiveCmd) + if agent != nil { + logrus.Debugf("Adding agent info: %s", agent.GetEnvVariables()) + e = e.Env(agent.GetEnvVariables()) + } + var proc *echo.Proc maxRetries := args.MaxRetries if maxRetries == 0 { diff --git a/ssh/ssh_test.go b/ssh/ssh_test.go index 0668a3c8..490973cc 100644 --- a/ssh/ssh_test.go +++ b/ssh/ssh_test.go @@ -26,7 +26,7 @@ func TestRun(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - expected, err := Run(test.args, test.cmd) + expected, err := Run(test.args, nil, test.cmd) if err != nil { t.Fatal(err) } @@ -54,7 +54,7 @@ func TestRunRead(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - reader, err := RunRead(test.args, test.cmd) + reader, err := RunRead(test.args, nil, test.cmd) if err != nil { t.Fatal(err) } diff --git a/ssh/test_support.go b/ssh/test_support.go index 22f986e9..483ce2dd 100644 --- a/ssh/test_support.go +++ b/ssh/test_support.go @@ -43,12 +43,12 @@ import ( func makeTestSSHDir(t *testing.T, args SSHArgs, dir string) { t.Logf("creating test dir over SSH: %s", dir) - _, err := Run(args, fmt.Sprintf(`mkdir -p %s`, dir)) + _, err := Run(args, nil, fmt.Sprintf(`mkdir -p %s`, dir)) if err != nil { t.Fatal(err) } // validate - result, _ := Run(args, fmt.Sprintf(`ls %s`, dir)) + result, _ := Run(args, nil, fmt.Sprintf(`ls %s`, dir)) t.Logf("dir created: %s", result) } @@ -59,18 +59,18 @@ func MakeTestSSHFile(t *testing.T, args SSHArgs, filePath, content string) { } t.Logf("creating test file over SSH: %s", filePath) - _, err := Run(args, fmt.Sprintf(`echo '%s' > %s`, content, filePath)) + _, err := Run(args, nil, fmt.Sprintf(`echo '%s' > %s`, content, filePath)) if err != nil { t.Fatal(err) } - result, _ := Run(args, fmt.Sprintf(`ls %s`, filePath)) + result, _ 
:= Run(args, nil, fmt.Sprintf(`ls %s`, filePath)) t.Logf("file created: %s", result) } func RemoveTestSSHFile(t *testing.T, args SSHArgs, fileName string) { t.Logf("removing test file over SSH: %s", fileName) - _, err := Run(args, fmt.Sprintf(`rm -rf %s`, fileName)) + _, err := Run(args, nil, fmt.Sprintf(`rm -rf %s`, fileName)) if err != nil { t.Fatal(err) } diff --git a/starlark/capa_provider.go b/starlark/capa_provider.go index bde80814..9fea68a0 100644 --- a/starlark/capa_provider.go +++ b/starlark/capa_provider.go @@ -75,7 +75,7 @@ func CapaProviderFn(thread *starlark.Thread, _ *starlark.Builtin, args starlark. // dictionary for capa provider struct capaProviderDict := starlark.StringDict{ - "kind": starlark.String(identifiers.capvProvider), + "kind": starlark.String(identifiers.capaProvider), "transport": starlark.String("ssh"), "kube_config": starlark.String(providerConfigPath), } diff --git a/starlark/capture.go b/starlark/capture.go index fd5991cb..132e84b1 100644 --- a/starlark/capture.go +++ b/starlark/capture.go @@ -10,11 +10,11 @@ import ( "path/filepath" "strings" + "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/vmware-tanzu/crash-diagnostics/ssh" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" - - "github.com/vmware-tanzu/crash-diagnostics/ssh" ) // captureFunc is a built-in starlark function that runs a provided command and @@ -58,7 +58,16 @@ func captureFunc(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tup resources = res } - results, err := execCapture(cmdStr, workdir, fileName, desc, resources) + var agent ssh.Agent + var ok bool + if agentVal := thread.Local(identifiers.sshAgent); agentVal != nil { + agent, ok = agentVal.(ssh.Agent) + if !ok { + return starlark.None, errors.New("unable to fetch ssh-agent") + } + } + + results, err := execCapture(cmdStr, workdir, fileName, desc, agent, resources) if err != nil { return starlark.None, fmt.Errorf("%s: %s", identifiers.capture, err) } @@ -75,7 +84,7 @@ 
func captureFunc(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tup return starlark.NewList(resultList), nil } -func execCapture(cmdStr, rootPath, fileName, desc string, resources *starlark.List) ([]commandResult, error) { +func execCapture(cmdStr, rootPath, fileName, desc string, agent ssh.Agent, resources *starlark.List) ([]commandResult, error) { if resources == nil { return nil, fmt.Errorf("%s: missing resources", identifiers.capture) } @@ -110,7 +119,7 @@ func execCapture(cmdStr, rootPath, fileName, desc string, resources *starlark.Li switch { case string(kind) == identifiers.hostResource && string(transport) == "ssh": - result, err := execCaptureSSH(host, cmdStr, rootDir, fileName, desc, res) + result, err := execCaptureSSH(host, cmdStr, rootDir, fileName, desc, agent, res) if err != nil { logrus.Errorf("%s failed: cmd=[%s]: %s", identifiers.capture, cmdStr, err) } @@ -124,7 +133,7 @@ func execCapture(cmdStr, rootPath, fileName, desc string, resources *starlark.Li return results, nil } -func execCaptureSSH(host, cmdStr, rootDir, fileName, desc string, res *starlarkstruct.Struct) (commandResult, error) { +func execCaptureSSH(host, cmdStr, rootDir, fileName, desc string, agent ssh.Agent, res *starlarkstruct.Struct) (commandResult, error) { sshCfg := starlarkstruct.FromKeywords(starlarkstruct.Default, makeDefaultSSHConfig()) if val, err := res.Attr(identifiers.sshCfg); err == nil { if cfg, ok := val.(*starlarkstruct.Struct); ok { @@ -151,7 +160,7 @@ func execCaptureSSH(host, cmdStr, rootDir, fileName, desc string, res *starlarks logrus.Debugf("%s: capturing output of [cmd=%s] => [%s] from %s using ssh", identifiers.capture, cmdStr, filePath, args.Host) - reader, err := ssh.RunRead(args, cmdStr) + reader, err := ssh.RunRead(args, agent, cmdStr) if err != nil { logrus.Errorf("%s failed: %s", identifiers.capture, err) if err := captureOutput(strings.NewReader(err.Error()), filePath, fmt.Sprintf("%s: failed", cmdStr)); err != nil { diff --git 
a/starlark/copy_from.go b/starlark/copy_from.go index 874352c6..26b0dfef 100644 --- a/starlark/copy_from.go +++ b/starlark/copy_from.go @@ -8,11 +8,11 @@ import ( "os" "path/filepath" + "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/vmware-tanzu/crash-diagnostics/ssh" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" - - "github.com/vmware-tanzu/crash-diagnostics/ssh" ) // copyFromFunc is a built-in starlark function that copies file resources from @@ -57,7 +57,16 @@ func copyFromFunc(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tu resources = res } - results, err := execCopy(workdir, sourcePath, resources) + var agent ssh.Agent + var ok bool + if agentVal := thread.Local(identifiers.sshAgent); agentVal != nil { + agent, ok = agentVal.(ssh.Agent) + if !ok { + return starlark.None, errors.New("unable to fetch ssh-agent") + } + } + + results, err := execCopy(workdir, sourcePath, agent, resources) if err != nil { return starlark.None, fmt.Errorf("%s: %s", identifiers.copyFrom, err) } @@ -74,7 +83,7 @@ func copyFromFunc(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tu return starlark.NewList(resultList), nil } -func execCopy(rootPath string, path string, resources *starlark.List) ([]commandResult, error) { +func execCopy(rootPath string, path string, agent ssh.Agent, resources *starlark.List) ([]commandResult, error) { if resources == nil { return nil, fmt.Errorf("%s: missing resources", identifiers.copyFrom) } @@ -108,7 +117,7 @@ func execCopy(rootPath string, path string, resources *starlark.List) ([]command switch { case string(kind) == identifiers.hostResource && string(transport) == "ssh": - result, err := execCopySCP(host, rootDir, path, res) + result, err := execCopySCP(host, rootDir, path, agent, res) if err != nil { logrus.Errorf("%s: failed to copyFrom %s: %s", identifiers.copyFrom, path, err) } @@ -122,7 +131,7 @@ func execCopy(rootPath string, path string, resources *starlark.List) ([]command return 
results, nil } -func execCopySCP(host, rootDir, path string, res *starlarkstruct.Struct) (commandResult, error) { +func execCopySCP(host, rootDir, path string, agent ssh.Agent, res *starlarkstruct.Struct) (commandResult, error) { sshCfg := starlarkstruct.FromKeywords(starlarkstruct.Default, makeDefaultSSHConfig()) if val, err := res.Attr(identifiers.sshCfg); err == nil { if cfg, ok := val.(*starlarkstruct.Struct); ok { @@ -141,6 +150,6 @@ func execCopySCP(host, rootDir, path string, res *starlarkstruct.Struct) (comman return commandResult{}, err } - err = ssh.CopyFrom(args, rootDir, path) + err = ssh.CopyFrom(args, agent, rootDir, path) return commandResult{resource: args.Host, result: filepath.Join(rootDir, path), err: err}, err } diff --git a/starlark/crashd_config.go b/starlark/crashd_config.go index b3221f35..2438065d 100644 --- a/starlark/crashd_config.go +++ b/starlark/crashd_config.go @@ -7,7 +7,9 @@ import ( "fmt" "os" + "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/vmware-tanzu/crash-diagnostics/ssh" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" ) @@ -29,10 +31,11 @@ func addDefaultCrashdConf(thread *starlark.Thread) error { return nil } -// crashConfig is built-in starlark function that saves and returns the kwargs as a struct value. +// crashdConfigFn is built-in starlark function that saves and returns the kwargs as a struct value. 
// Starlark format: crashd_config(workdir=path, default_shell=shellpath, requires=["command0",...,"commandN"]) -func crashdConfigFn(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { +func crashdConfigFn(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { var workdir, gid, uid, defaultShell string + var useSSHAgent bool requires := starlark.NewList([]starlark.Value{}) if err := starlark.UnpackArgs( @@ -42,6 +45,7 @@ func crashdConfigFn(thread *starlark.Thread, b *starlark.Builtin, args starlark. "uid?", &uid, "default_shell?", &defaultShell, "requires?", &requires, + "use_ssh_agent?", &useSSHAgent, ); err != nil { return starlark.None, fmt.Errorf("%s: %s", identifiers.crashdCfg, err) } @@ -63,6 +67,16 @@ func crashdConfigFn(thread *starlark.Thread, b *starlark.Builtin, args starlark. return starlark.None, fmt.Errorf("%s: %s", identifiers.crashdCfg, err) } + if useSSHAgent { + agent, err := ssh.StartAgent() + if err != nil { + return starlark.None, errors.Wrap(err, "failed to start ssh agent") + } + + // sets the ssh_agent variable in the current Starlark thread + thread.SetLocal(identifiers.sshAgent, agent) + } + cfgStruct := starlarkstruct.FromStringDict(starlark.String(identifiers.crashdCfg), starlark.StringDict{ "workdir": starlark.String(workdir), "gid": starlark.String(gid), diff --git a/starlark/crashd_config_test.go b/starlark/crashd_config_test.go index e29122be..b0d2613c 100644 --- a/starlark/crashd_config_test.go +++ b/starlark/crashd_config_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/vmware-tanzu/crash-diagnostics/ssh" "go.starlark.net/starlarkstruct" ) @@ -92,6 +93,26 @@ func testCrashdConfigFunc(t *testing.T) { } }, }, + + { + name: "crash_config with use-ssh-agent", + script: `crashd_config(workdir="fooval", default_shell="barval", use_ssh_agent=True)`, + eval: func(t *testing.T, script string) { + 
defer os.RemoveAll("fooval") + exe := New() + if err := exe.Exec("test.star", strings.NewReader(script)); err != nil { + t.Fatal(err) + } + data := exe.thread.Local(identifiers.sshAgent) + if data == nil { + t.Fatal("use_ssh_agent identifier not saved in thread local") + } + agent, ok := data.(ssh.Agent) + if !ok || agent == nil { + t.Fatal("ssh agent should have been started") + } + }, + }, } for _, test := range tests { diff --git a/starlark/resources.go b/starlark/resources.go index 112195d6..ad0d8672 100644 --- a/starlark/resources.go +++ b/starlark/resources.go @@ -65,7 +65,7 @@ func enum(provider *starlarkstruct.Struct) (*starlark.List, error) { kind := trimQuotes(kindVal.String()) switch kind { - case identifiers.hostListProvider, identifiers.kubeNodesProvider, identifiers.capvProvider: + case identifiers.hostListProvider, identifiers.kubeNodesProvider, identifiers.capvProvider, identifiers.capaProvider: hosts, err := provider.Attr("hosts") if err != nil { return nil, fmt.Errorf("hosts not found in %s", identifiers.hostListProvider) diff --git a/starlark/run.go b/starlark/run.go index 6ae728e7..30bde028 100644 --- a/starlark/run.go +++ b/starlark/run.go @@ -6,6 +6,7 @@ package starlark import ( "fmt" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" @@ -40,7 +41,7 @@ func (r commandResult) toStarlarkStruct() *starlarkstruct.Struct { // about the executed command on the provided compute resources. If resources // is not provided, runFunc uses the default resources found in the starlark thread. 
// Starlark format: run(cmd="command" [,resources=resources]) -func runFunc(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { +func runFunc(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { var cmdStr string var resources *starlark.List if err := starlark.UnpackArgs( @@ -63,7 +64,16 @@ func runFunc(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, resources = resList } - results, err := execRun(cmdStr, resources) + var agent ssh.Agent + var ok bool + if agentVal := thread.Local(identifiers.sshAgent); agentVal != nil { + agent, ok = agentVal.(ssh.Agent) + if !ok { + return starlark.None, errors.New("unable to fetch ssh-agent") + } + } + + results, err := execRun(cmdStr, agent, resources) if err != nil { return starlark.None, err } @@ -80,7 +90,7 @@ func runFunc(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, return starlark.NewList(resultList), nil } -func execRun(cmdStr string, resources *starlark.List) ([]commandResult, error) { +func execRun(cmdStr string, agent ssh.Agent, resources *starlark.List) ([]commandResult, error) { if resources == nil { return nil, fmt.Errorf("%s: missing resources", identifiers.run) } @@ -108,7 +118,7 @@ func execRun(cmdStr string, resources *starlark.List) ([]commandResult, error) { switch { case string(kind) == identifiers.hostResource && string(transport) == "ssh": - result, err := execRunSSH(cmdStr, res) + result, err := execRunSSH(cmdStr, agent, res) if err != nil { logrus.Error(err) continue @@ -124,7 +134,7 @@ func execRun(cmdStr string, resources *starlark.List) ([]commandResult, error) { } // execRunSSH executes `run` command for a Host Resource using SSH -func execRunSSH(cmdStr string, res *starlarkstruct.Struct) (commandResult, error) { +func execRunSSH(cmdStr string, agent ssh.Agent, res *starlarkstruct.Struct) (commandResult, error) { sshCfg := 
starlarkstruct.FromKeywords(starlarkstruct.Default, makeDefaultSSHConfig()) if val, err := res.Attr(identifiers.sshCfg); err == nil { if cfg, ok := val.(*starlarkstruct.Struct); ok { @@ -149,7 +159,7 @@ func execRunSSH(cmdStr string, res *starlarkstruct.Struct) (commandResult, error args.Host = string(host) logrus.Debugf("%s: executing command on %s using ssh: [%s]", identifiers.run, args.Host, cmdStr) - cmdResult, err := ssh.Run(args, cmdStr) + cmdResult, err := ssh.Run(args, agent, cmdStr) return commandResult{resource: args.Host, result: cmdResult, err: err}, nil } diff --git a/starlark/ssh_config.go b/starlark/ssh_config.go index 63c278aa..dc21f524 100644 --- a/starlark/ssh_config.go +++ b/starlark/ssh_config.go @@ -6,24 +6,27 @@ package starlark import ( "fmt" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/vmware-tanzu/crash-diagnostics/ssh" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" ) -// addDefaultSshConf initalizes a Starlark Dict with default +// addDefaultSshConf initializes a Starlark Dict with default // ssh_config configuration data func addDefaultSSHConf(thread *starlark.Thread) error { args := makeDefaultSSHConfig() - _, err := sshConfigFn(thread, nil, nil, args) + _, err := SshConfigFn(thread, nil, nil, args) if err != nil { return err } return nil } -// sshConfigFn is the backing built-in fn that saves and returns its argument as struct value. +// SshConfigFn is the backing built-in fn that saves and returns its argument as struct value. 
// Starlark format: ssh_config(username=name[, port][, private_key_path][,max_retries][,conn_timeout][,jump_user][,jump_host]) -func sshConfigFn(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { +func SshConfigFn(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { var uname, port, pkPath, jUser, jHost string var maxRetries, connTimeout int @@ -57,6 +60,17 @@ func sshConfigFn(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, k pkPath = defaults.pkPath } + if agentVal := thread.Local(identifiers.sshAgent); agentVal != nil { + agent, ok := agentVal.(ssh.Agent) + if !ok { + return starlark.None, errors.New("unable to fetch ssh-agent") + } + logrus.Debugf("adding key %s to ssh-agent", pkPath) + if err := agent.AddKey(pkPath); err != nil { + return starlark.None, errors.Wrapf(err, "unable to add key %s", pkPath) + } + } + sshConfigDict := starlark.StringDict{ "username": starlark.String(uname), "port": starlark.String(port), @@ -77,10 +91,10 @@ func sshConfigFn(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, k func makeDefaultSSHConfig() []starlark.Tuple { return []starlark.Tuple{ - starlark.Tuple{starlark.String("username"), starlark.String(getUsername())}, - starlark.Tuple{starlark.String("port"), starlark.String("22")}, - starlark.Tuple{starlark.String("private_key_path"), starlark.String(defaults.pkPath)}, - starlark.Tuple{starlark.String("max_retries"), starlark.MakeInt(defaults.connRetries)}, - starlark.Tuple{starlark.String("conn_timeout"), starlark.MakeInt(defaults.connTimeout)}, + {starlark.String("username"), starlark.String(getUsername())}, + {starlark.String("port"), starlark.String("22")}, + {starlark.String("private_key_path"), starlark.String(defaults.pkPath)}, + {starlark.String("max_retries"), starlark.MakeInt(defaults.connRetries)}, + {starlark.String("conn_timeout"), 
starlark.MakeInt(defaults.connTimeout)}, } } diff --git a/starlark/starlark_exec.go b/starlark/starlark_exec.go index 2428c573..e6ba0c6a 100644 --- a/starlark/starlark_exec.go +++ b/starlark/starlark_exec.go @@ -8,6 +8,8 @@ import ( "fmt" "io" + "github.com/sirupsen/logrus" + "github.com/vmware-tanzu/crash-diagnostics/ssh" "go.starlark.net/starlark" ) @@ -45,6 +47,19 @@ func (e *Executor) Exec(name string, source io.Reader) error { } e.result = result + // fetch and stop the instance of ssh-agent, if any + if agentVal := e.thread.Local(identifiers.sshAgent); agentVal != nil { + logrus.Debug("stopping ssh-agent") + agent, ok := agentVal.(ssh.Agent) + if !ok { + logrus.Warn("error fetching ssh-agent") + } else { + if e := agent.Stop(); e != nil { + logrus.Warnf("failed to stop ssh-agent: %v", e) + } + } + } + return nil } @@ -76,7 +91,7 @@ func newPredeclareds() starlark.StringDict { return starlark.StringDict{ identifiers.os: setupOSStruct(), identifiers.crashdCfg: starlark.NewBuiltin(identifiers.crashdCfg, crashdConfigFn), - identifiers.sshCfg: starlark.NewBuiltin(identifiers.sshCfg, sshConfigFn), + identifiers.sshCfg: starlark.NewBuiltin(identifiers.sshCfg, SshConfigFn), identifiers.hostListProvider: starlark.NewBuiltin(identifiers.hostListProvider, hostListProvider), identifiers.resources: starlark.NewBuiltin(identifiers.resources, resourcesFunc), identifiers.archive: starlark.NewBuiltin(identifiers.archive, archiveFunc), @@ -90,7 +105,7 @@ func newPredeclareds() starlark.StringDict { identifiers.kubeGet: starlark.NewBuiltin(identifiers.kubeGet, KubeGetFn), identifiers.kubeNodesProvider: starlark.NewBuiltin(identifiers.kubeNodesProvider, KubeNodesProviderFn), identifiers.capvProvider: starlark.NewBuiltin(identifiers.capvProvider, CapvProviderFn), - identifiers.capaProvider: starlark.NewBuiltin(identifiers.capvProvider, CapaProviderFn), + identifiers.capaProvider: starlark.NewBuiltin(identifiers.capaProvider, CapaProviderFn), identifiers.setDefaults: 
starlark.NewBuiltin(identifiers.setDefaults, SetDefaultsFunc), } } diff --git a/starlark/support.go b/starlark/support.go index 14078b6c..54f257b5 100644 --- a/starlark/support.go +++ b/starlark/support.go @@ -47,6 +47,8 @@ var ( kubeNodesProvider string capvProvider string capaProvider string + + sshAgent string }{ crashdCfg: "crashd_config", kubeCfg: "kube_config", @@ -76,6 +78,8 @@ var ( kubeNodesProvider: "kube_nodes_provider", capvProvider: "capv_provider", capaProvider: "capa_provider", + + sshAgent: "crashd_ssh_agent", } defaults = struct {