Skip to content

Commit

Permalink
expanded ssh_config parameters for qemu+ssh uri option (dmacvicar#1059)
Browse files Browse the repository at this point in the history
* bump ssh_config to stable version which contains GetAll call
* refactor dialSSH and break out dialHost as support function
* move authentication parsing to per host part of loop

this allows different hosts (jump hosts) to have different identity files
specified

* implement per Host identity file lookup
* fix tilde (~) based home directory notation for convenience
* updated go.sum
* cleanup log outputs
* remove unnecessary local variable
* make use of net package URI building to support correct ipv6

as per commit from MaxMatti:
dmacvicar@1152bdd

* correctly use host:port format when dialing bastion host
* put quotes around target in case it is empty
* if the hostname override isn't present, simply use target name
* add log output for port override
* add default host key algorithm
* move port configuration earlier so that hostkey callback works right
the hostKeyCallback makes use of the SSH port and fails if a custom ssh port is
being used by the host

* cleanup log output, add error handling for dial host
* add support for sshconfig based known hosts file behaviour
* integrate HostKeyAlgorithms ssh_config option
* move dial host impl so that bastion hosts have same features
* add comments
* use a more modern default host key

this value was chosen as the lowest RSA available by default on a debian build
running OpenSSH_9.2 and works out of the box for most hosts tested by authority.
Any older systems can specifically set their key algorithms in .ssh/config

* update auth method parse to allow for multiple private ssh keys
* use a list of hostKeyAlgorithms instead of just one default
* use camelCase to match go coding styles
* use join instead of replace for a more predictable outcome

replace could have resulted in weird behaviour such as "some~path" becoming
incorrectly mangled

* remove log.Fatal and let the upper layer deal with the logging
* change magic number to constant
* code formatting to match go coding style
* add missing import for filepath module (0da4763)
* lint fixes
  • Loading branch information
memetb authored and dmacvicar committed Sep 28, 2024
1 parent 621d48d commit dc44730
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 48 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/google/uuid v1.3.0
github.com/hashicorp/terraform-plugin-sdk/v2 v2.24.1
github.com/hooklift/iso9660 v1.0.0
github.com/kevinburke/ssh_config v0.0.0-20201106050909-4977a11b4351
github.com/kevinburke/ssh_config v1.2.0
github.com/mattn/goveralls v0.0.11
github.com/stretchr/testify v1.8.1
golang.org/x/crypto v0.21.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVY
github.com/kevinburke/ssh_config v0.0.0-20190725054713-01f96b0aa0cd/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM=
github.com/kevinburke/ssh_config v0.0.0-20201106050909-4977a11b4351 h1:DowS9hvgyYSX4TO5NpyC606/Z4SxnNYbT+WX27or6Ck=
github.com/kevinburke/ssh_config v0.0.0-20201106050909-4977a11b4351/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM=
github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4=
github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM=
github.com/keybase/go-crypto v0.0.0-20161004153544-93f5b35093ba/go.mod h1:ghbZscTyKdM07+Fw3KSi0hcJm+AlEUWj8QLlPtijN/M=
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
Expand Down
237 changes: 190 additions & 47 deletions libvirt/uri/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"os"
"os/user"
"path/filepath"
"strings"

"github.com/kevinburke/ssh_config"
Expand All @@ -15,26 +16,48 @@ import (
)

const (
maxHostHops = 10
defaultSSHPort = "22"
defaultSSHKeyPath = "${HOME}/.ssh/id_rsa"
defaultSSHKnownHostsPath = "${HOME}/.ssh/known_hosts"
defaultSSHConfigFile = "${HOME}/.ssh/config"
defaultSSHAuthMethods = "agent,privkey"
)

func (u *ConnectionURI) parseAuthMethods() []ssh.AuthMethod {
func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Config) []ssh.AuthMethod {
q := u.Query()

authMethods := q.Get("sshauth")
if authMethods == "" {
authMethods = defaultSSHAuthMethods
}

log.Printf("[DEBUG] auth methods for %v: %v", target, authMethods)

// keyfile order of precedence:
// 1. load uri encoded keyfile
// 2. load override as specified in ssh config
// 3. load default ssh keyfile path
sshKeyPaths := []string {}
sshKeyPath := q.Get("keyfile")
if sshKeyPath == "" {
sshKeyPath = defaultSSHKeyPath
if sshKeyPath != "" {
sshKeyPaths = append(sshKeyPaths, sshKeyPath)
}

keyPaths, err := sshcfg.GetAll(target, "IdentityFile")
if err != nil {
log.Printf("[WARN] unable to get IdentityFile values - ignoring")
} else {
sshKeyPaths = append(sshKeyPaths, keyPaths...)
}

if len(keyPaths) == 0 {
log.Printf("[DEBUG] found no ssh keys, using default keypath")
sshKeyPaths = []string{defaultSSHKeyPath}
}

log.Printf("[DEBUG] ssh identity files for host '%s': %s", target, sshKeyPaths);

auths := strings.Split(authMethods, ",")
result := make([]ssh.AuthMethod, 0)
for _, v := range auths {
Expand All @@ -53,17 +76,27 @@ func (u *ConnectionURI) parseAuthMethods() []ssh.AuthMethod {
agentClient := agent.NewClient(conn)
result = append(result, ssh.PublicKeysCallback(agentClient.Signers))
case "privkey":
sshKey, err := os.ReadFile(os.ExpandEnv(sshKeyPath))
if err != nil {
log.Printf("[ERROR] Failed to read ssh key: %v", err)
continue
}
for _, keypath := range sshKeyPaths {
log.Printf("[DEBUG] Reading ssh key '%s'", keypath)
path := os.ExpandEnv(keypath)
if strings.HasPrefix(path, "~/") {
home, err := os.UserHomeDir()
if err == nil {
path = filepath.Join(home, path[2:])
}
}
sshKey, err := os.ReadFile(path)
if err != nil {
log.Printf("[ERROR] Failed to read ssh key '%s': %v", keypath, err)
continue
}

signer, err := ssh.ParsePrivateKey(sshKey)
if err != nil {
log.Printf("[ERROR] Failed to parse ssh key: %v", err)
signer, err := ssh.ParsePrivateKey(sshKey)
if err != nil {
log.Printf("[ERROR] Failed to parse ssh key: %v", err)
}
result = append(result, ssh.PublicKeys(signer))
}
result = append(result, ssh.PublicKeys(signer))
case "ssh-password":
if sshPassword, ok := u.User.Password(); ok {
result = append(result, ssh.Password(sshPassword))
Expand All @@ -79,6 +112,8 @@ func (u *ConnectionURI) parseAuthMethods() []ssh.AuthMethod {
return result
}

// construct the whole ssh connection, which can consist of multiple hops if using proxy jumps,
// the ssh configuration file is loaded once and passed along to each host connection.
func (u *ConnectionURI) dialSSH() (net.Conn, error) {
sshConfigFile, err := os.Open(os.ExpandEnv(defaultSSHConfigFile))
if err != nil {
Expand All @@ -90,74 +125,182 @@ func (u *ConnectionURI) dialSSH() (net.Conn, error) {
log.Printf("[WARN] Failed to parse ssh config file: %v", err)
}

authMethods := u.parseAuthMethods()
if len(authMethods) < 1 {
return nil, fmt.Errorf("could not configure SSH authentication methods")
// configuration loaded, build tunnel
sshClient, err := u.dialHost(u.Host, sshcfg, 0)
if err != nil {
return nil, err
}

// tunnel established, connect to the libvirt unix socket to communicate
// e.g. /var/run/libvirt/libvirt-sock
address := u.Query().Get("socket")
if address == "" {
address = defaultUnixSock
}

c, err := sshClient.Dial("unix", address)
if err != nil {
return nil, fmt.Errorf("failed to connect to libvirt on the remote host: %w", err)
}

return c, nil
}

func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth int) (*ssh.Client, error) {

if depth > maxHostHops {
return nil, fmt.Errorf("[ERROR] dialHost failed: max tunnel depth of 10 reached")
}

log.Printf("[INFO] establishing ssh connection to '%s'", target);

q := u.Query()

port := u.Port()
if port == "" {
port = defaultSSHPort
} else {
log.Printf("[DEBUG] ssh Port is overridden to: '%s'", port);
}

hostName, err := sshcfg.Get(target, "HostName")
if err == nil {
if hostName == "" {
hostName = target;
} else {
log.Printf("[DEBUG] HostName is overridden to: '%s'", hostName);
}
}

// we must check for knownhosts and verification for each host we connect to.
// the query string values have higher precedence to local configs
knownHostsPath := q.Get("knownhosts")
knownHostsVerify := q.Get("known_hosts_verify")
doVerify := q.Get("no_verify") == ""
skipVerify := q.Has("no_verify")

if knownHostsVerify == "ignore" {
doVerify = false
skipVerify = true
} else {
strictCheck, err := sshcfg.Get(target, "StrictHostKeyChecking")
if err != nil && strictCheck == "yes" {
skipVerify = false
}
}

if knownHostsPath == "" {
knownHostsPath = defaultSSHKnownHostsPath
knownHosts, err := sshcfg.Get(target, "UserKnownHostsFile")
if err == nil && knownHosts != "" {
knownHostsPath = knownHosts
} else {
knownHostsPath = defaultSSHKnownHostsPath
}
}

hostKeyCallback := ssh.InsecureIgnoreHostKey()
if doVerify {
cb, err := knownhosts.New(os.ExpandEnv(knownHostsPath))
hostKeyAlgorithms := []string{ // https://github.com/golang/go/issues/29286
// this can be solved using https://github.com/skeema/knownhosts/tree/main
// there is an open issue requiring attention
ssh.KeyAlgoED25519,
ssh.KeyAlgoRSA,
ssh.KeyAlgoRSASHA256,
ssh.KeyAlgoRSASHA512,
ssh.KeyAlgoSKECDSA256,
ssh.KeyAlgoSKED25519,
ssh.KeyAlgoECDSA256,
ssh.KeyAlgoECDSA384,
ssh.KeyAlgoECDSA521,
}
if !skipVerify {
kh, err := knownhosts.New(os.ExpandEnv(knownHostsPath))
if err != nil {
return nil, fmt.Errorf("failed to read ssh known hosts: %w", err)
}
hostKeyCallback = cb
}
log.Printf("[DEBUG] Using known hosts file '%s' for target '%s'", os.ExpandEnv(knownHostsPath), target)

username := u.User.Username()
if username == "" {
sshu, err := sshcfg.Get(u.Host, "User")
log.Printf("[DEBUG] SSH User: %v", sshu)
if err != nil {
log.Printf("[DEBUG] ssh user: system username")
u, err := user.Current()
hostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error {
err := kh(net.JoinHostPort(hostName, port), remote, key)
if err != nil {
return nil, fmt.Errorf("unable to get username: %w", err)
log.Printf("Host key verification failed for host '%s' (%s) %v: %v", hostName, remote, key, err)
}
sshu = u.Username
return err
}
username = sshu

keyAlgs, err := sshcfg.Get(target, "HostKeyAlgorithms")
if err == nil && keyAlgs != "" {
log.Printf("Got host key algorithms '%s'", keyAlgs)
hostKeyAlgorithms = strings.Split(keyAlgs, ",")
}

}

cfg := ssh.ClientConfig{
User: username,
User: u.User.Username(),
HostKeyCallback: hostKeyCallback,
Auth: authMethods,
HostKeyAlgorithms: hostKeyAlgorithms,
Timeout: dialTimeout,
}

port := u.Port()
if port == "" {
port = defaultSSHPort
proxy, err := sshcfg.Get(target, "ProxyCommand")
if err == nil && proxy != "" {
log.Printf("[WARNING] unsupported ssh ProxyCommand '%v'", proxy)
}

sshClient, err := ssh.Dial("tcp", fmt.Sprintf("%s:%s", u.Hostname(), port), &cfg)
if err != nil {
return nil, err
proxy, err = sshcfg.Get(target, "ProxyJump")
var bastion *ssh.Client
if err == nil && proxy != "" {
log.Printf("[DEBUG] found ProxyJump '%v'", proxy)

// this is a proxy jump: we recurse into that proxy
bastion, err = u.dialHost(proxy, sshcfg, depth + 1)
if err != nil {
return nil, fmt.Errorf("failed to connect to bastion host '%v': %w", proxy, err)
}
}

address := q.Get("socket")
if address == "" {
address = defaultUnixSock
if cfg.User == "" {
sshu, err := sshcfg.Get(target, "User")
log.Printf("[DEBUG] SSH User for target '%v' is '%v'", target, sshu)
if err != nil {
log.Printf("[DEBUG] ssh user: using current login")
u, err := user.Current()
if err != nil {
return nil, fmt.Errorf("unable to get username: %w", err)
}
sshu = u.Username
}
cfg.User = sshu
}

c, err := sshClient.Dial("unix", address)
if err != nil {
return nil, fmt.Errorf("failed to connect to libvirt on the remote host: %w", err)
cfg.Auth = u.parseAuthMethods(target, sshcfg)
if len(cfg.Auth) < 1 {
return nil, fmt.Errorf("could not configure SSH authentication methods")
}

return c, nil
if (bastion != nil) {
// if this is a proxied connection, we want to dial through the bastion host
log.Printf("[INFO] SSH connecting to '%v' (%v) through bastion host '%v'", target, hostName, proxy)
// Dial a connection to the service host, from the bastion
conn, err := bastion.Dial("tcp", net.JoinHostPort(hostName, port))
if err != nil {
return nil, fmt.Errorf("failed to connect to remote host '%v': %w", target, err)
}

ncc, chans, reqs, err := ssh.NewClientConn(conn, target, &cfg)
if err != nil {
return nil, fmt.Errorf("failed to connect to remote host '%v': %w", target, err)
}

sClient := ssh.NewClient(ncc, chans, reqs)
return sClient, nil

} else {
// this is a direct connection to the target host
log.Printf("[INFO] SSH connecting to '%v' (%v)", target, hostName)
conn,err := ssh.Dial("tcp", net.JoinHostPort(hostName, port), &cfg)

if err != nil {
return nil, fmt.Errorf("failed to connect to remote host '%v': %w", target, err)
}
return conn, nil
}
}

0 comments on commit dc44730

Please sign in to comment.