diff --git a/CHANGELOG.md b/CHANGELOG.md index 7dcc1d9..897e02c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ ## 0.2.0 / TBD * Update to 1.15 and update Go module dependencies +* Add `known_hosts` configuration option to allow verifying SSH hosts against known hosts +* Add `host_key_algorithms` configuration option to specify host key algorithms to use when verifying SSH hosts ## 0.1.1 / 2020-04-01 diff --git a/README.md b/README.md index 496561a..53f823d 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,15 @@ modules: password: id: prometheus password: secret + verify: + user: prometheus + private_key: /home/prometheus/.ssh/id_rsa + known_hosts: /etc/ssh/ssh_known_hosts + host_key_algorithms: + - ssh-rsa + command: uptime + command_expect: "load average" + timeout: 5 ``` Example with curl would query host1 with the password module and host2 with the default module. @@ -47,6 +56,9 @@ Configuration options for each module: * `user` - The username for the SSH connection * `password` - The password for the SSH connection, required if `private_key` is not specified * `private_key` - The SSH private key for the SSH connection, required if `password` is not specified +* `known_hosts` - Optional SSH known hosts file to use to verify hosts +* `host_key_algorithms` - Optional list of SSH host key algorithms to use + * See constants beginning with `KeyAlgo*` in [crypto/ssh](https://godoc.org/golang.org/x/crypto/ssh#pkg-constants) * `timeout` - Optional timeout of the SSH connection, session and optional command. * The default comes from the `--collector.ssh.default-timeout` flag. * `command` - Optional command to run. @@ -110,8 +122,15 @@ The following example assumes this exporter is running on the Prometheus server metrics_path: /ssh static_configs: - targets: - - ssh1.example.com - - ssh2.example.com + - host1.example.com:22 + - host2.example.com:22 + labels: + module: default + - targets: + - host3.example.com:22 + - host4.example.com:22 + labels: + module: verify relabel_configs: - source_labels: [__address__] target_label: __param_target @@ -119,6 +138,11 @@ The following example assumes this exporter is running on the Prometheus server target_label: instance - target_label: __address__ replacement: 127.0.0.1:9312 + - source_labels: [module] + target_label: __param_module + metric_relabel_configs: + - regex: "^(module)$" + action: labeldrop - job_name: ssh-metrics metrics_path: /metrics static_configs: diff --git a/collector/collector.go b/collector/collector.go index d64f0e3..29aa325 100644 --- a/collector/collector.go +++ b/collector/collector.go @@ -15,7 +15,9 @@ package collector import ( "bytes" + "encoding/base64" "io/ioutil" + "net" "regexp" "strings" "time" @@ -25,6 +27,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/treydock/ssh_exporter/config" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" ) const ( @@ -32,7 +35,7 @@ const ( ) type Metric struct { - Success bool + Success float64 FailureReason string } @@ -70,7 +73,7 @@ func (c *Collector) Collect(ch chan<- prometheus.Metric) { metric := c.collect() - ch <- prometheus.MustNewConstMetric(c.Success, prometheus.GaugeValue, boolToFloat64(metric.Success)) + ch <- prometheus.MustNewConstMetric(c.Success, prometheus.GaugeValue, metric.Success) for _, reason := range failureReasons { var value float64 if reason == metric.FailureReason { @@ -103,10 +106,11 @@ func (c *Collector) collect() Metric { } sshConfig := &ssh.ClientConfig{ - User: c.target.User, - Auth: []ssh.AuthMethod{auth}, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - Timeout: time.Duration(c.target.Timeout) * time.Second, + User: c.target.User, + Auth: []ssh.AuthMethod{auth}, + HostKeyCallback: hostKeyCallback(&metric, c.target, c.logger), + HostKeyAlgorithms: c.target.HostKeyAlgorithms, + Timeout: time.Duration(c.target.Timeout) * time.Second, } connection, err := ssh.Dial("tcp", c.target.Host, sshConfig) if err != nil { @@ -170,7 +174,7 @@ func (c *Collector) collect() Metric { return metric } } - metric.Success = true + metric.Success = 1 return metric } @@ -186,10 +190,22 @@ func getPrivateKeyAuth(privatekey string) (ssh.AuthMethod, error) { return ssh.PublicKeys(key), nil } -func boolToFloat64(data bool) float64 { - if data { - return float64(1) - } else { - return float64(0) +func hostKeyCallback(metric *Metric, target *config.Target, logger log.Logger) ssh.HostKeyCallback { + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + var hostKeyCallback ssh.HostKeyCallback + var err error + if target.KnownHosts != "" { + publicKey := base64.StdEncoding.EncodeToString(key.Marshal()) + level.Debug(logger).Log("msg", "Verify SSH known hosts", "hostname", hostname, "remote", remote.String(), "key", publicKey) + hostKeyCallback, err = knownhosts.New(target.KnownHosts) + if err != nil { + metric.FailureReason = "error" + level.Error(logger).Log("msg", "Error creating hostkeycallback function", "err", err) + return err + } + } else { + hostKeyCallback = ssh.InsecureIgnoreHostKey() + } + return hostKeyCallback(hostname, remote, key) } } diff --git a/collector/collector_test.go b/collector/collector_test.go index 25aa570..d1b2f9d 100644 --- a/collector/collector_test.go +++ b/collector/collector_test.go @@ -14,6 +14,8 @@ package collector import ( + "crypto/rand" + "crypto/rsa" "fmt" "io" "io/ioutil" @@ -26,12 +28,16 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" "github.com/treydock/ssh_exporter/config" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" ) const ( listen = 60022 ) +var knownHosts *os.File + func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { buffer, err := ioutil.ReadFile("testdata/id_rsa_test1.pub") if err != nil { @@ -69,6 +75,27 @@ func TestMain(m *testing.M) { PublicKeyHandler: publicKeyHandler, PasswordHandler: passwordHandler, } + hostKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + fmt.Printf("ERROR generating RSA host key: %s", err) + os.Exit(1) + } + signer, err := gossh.NewSignerFromKey(hostKey) + if err != nil { + fmt.Printf("ERROR generating host key signer: %s", err) + os.Exit(1) + } + s.AddHostKey(signer) + knownHosts, err = ioutil.TempFile("", "knowm_hosts") + if err != nil { + fmt.Printf("ERROR creating known hosts: %s", err) + os.Exit(1) + } + defer os.Remove(knownHosts.Name()) + knownHostsLine := knownhosts.Line([]string{fmt.Sprintf("localhost:%d", listen)}, s.HostSigners[0].PublicKey()) + if _, err = knownHosts.Write([]byte(knownHostsLine)); err != nil { + fmt.Printf("ERROR writing known hosts: %s", err) + } go func() { if err := s.ListenAndServe(); err != nil { fmt.Printf("ERROR starting SSH server: %s", err) @@ -279,6 +306,108 @@ func TestCollectorPrivateKey(t *testing.T) { } } +func TestCollectorKnownHosts(t *testing.T) { + expected := ` + # HELP ssh_failure Indicates a failure + # TYPE ssh_failure gauge + ssh_failure{reason="command-error"} 0 + ssh_failure{reason="command-output"} 0 + ssh_failure{reason="error"} 0 + ssh_failure{reason="timeout"} 0 + # HELP ssh_success SSH connection was successful + # TYPE ssh_success gauge + ssh_success 1 + ` + target := &config.Target{ + Host: fmt.Sprintf("localhost:%d", listen), + User: "test", + PrivateKey: "testdata/id_rsa_test1", + KnownHosts: knownHosts.Name(), + Timeout: 2, + } + w := log.NewSyncWriter(os.Stderr) + logger := log.NewLogfmtLogger(w) + collector := NewCollector(target, logger) + gatherers := setupGatherer(collector) + if val, err := testutil.GatherAndCount(gatherers); err != nil { + t.Errorf("Unexpected error: %v", err) + } else if val != 6 { + t.Errorf("Unexpected collection count %d, expected 6", val) + } + if err := testutil.GatherAndCompare(gatherers, strings.NewReader(expected), + "ssh_success", "ssh_failure"); err != nil { + t.Errorf("unexpected collecting result:\n%s", err) + } +} + +func TestCollectorKnownHostsError(t *testing.T) { + expected := ` + # HELP ssh_failure Indicates a failure + # TYPE ssh_failure gauge + ssh_failure{reason="command-error"} 0 + ssh_failure{reason="command-output"} 0 + ssh_failure{reason="error"} 1 + ssh_failure{reason="timeout"} 0 + # HELP ssh_success SSH connection was successful + # TYPE ssh_success gauge + ssh_success 0 + ` + target := &config.Target{ + Host: fmt.Sprintf("127.0.0.1:%d", listen), + User: "test", + PrivateKey: "testdata/id_rsa_test1", + KnownHosts: knownHosts.Name(), + Timeout: 2, + } + w := log.NewSyncWriter(os.Stderr) + logger := log.NewLogfmtLogger(w) + collector := NewCollector(target, logger) + gatherers := setupGatherer(collector) + if val, err := testutil.GatherAndCount(gatherers); err != nil { + t.Errorf("Unexpected error: %v", err) + } else if val != 6 { + t.Errorf("Unexpected collection count %d, expected 6", val) + } + if err := testutil.GatherAndCompare(gatherers, strings.NewReader(expected), + "ssh_success", "ssh_failure"); err != nil { + t.Errorf("unexpected collecting result:\n%s", err) + } +} + +func TestCollectorKnownHostsDNE(t *testing.T) { + expected := ` + # HELP ssh_failure Indicates a failure + # TYPE ssh_failure gauge + ssh_failure{reason="command-error"} 0 + ssh_failure{reason="command-output"} 0 + ssh_failure{reason="error"} 1 + ssh_failure{reason="timeout"} 0 + # HELP ssh_success SSH connection was successful + # TYPE ssh_success gauge + ssh_success 0 + ` + target := &config.Target{ + Host: fmt.Sprintf("localhost:%d", listen), + User: "test", + PrivateKey: "testdata/id_rsa_test1", + KnownHosts: "/dne", + Timeout: 2, + } + w := log.NewSyncWriter(os.Stderr) + logger := log.NewLogfmtLogger(w) + collector := NewCollector(target, logger) + gatherers := setupGatherer(collector) + if val, err := testutil.GatherAndCount(gatherers); err != nil { + t.Errorf("Unexpected error: %v", err) + } else if val != 6 { + t.Errorf("Unexpected collection count %d, expected 6", val) + } + if err := testutil.GatherAndCompare(gatherers, strings.NewReader(expected), + "ssh_success", "ssh_failure"); err != nil { + t.Errorf("unexpected collecting result:\n%s", err) + } +} + func TestCollectDNEKey(t *testing.T) { target := &config.Target{ Host: fmt.Sprintf("localhost:%d", listen), diff --git a/config/config.go b/config/config.go index ef0af9b..03a7564 100644 --- a/config/config.go +++ b/config/config.go @@ -31,23 +31,27 @@ type SafeConfig struct { } type Module struct { - ModuleName string - User string `yaml:"user"` - Password string `yaml:"password"` - PrivateKey string `yaml:"private_key"` - Timeout int `yaml:"timeout"` - Command string `yaml:"command"` - CommandExpect string `yaml:"command_expect"` + ModuleName string + User string `yaml:"user"` + Password string `yaml:"password"` + PrivateKey string `yaml:"private_key"` + KnownHosts string `yaml:"known_hosts"` + HostKeyAlgorithms []string `yaml:"host_key_algorithms"` + Timeout int `yaml:"timeout"` + Command string `yaml:"command"` + CommandExpect string `yaml:"command_expect"` } type Target struct { - Host string - User string - Password string - PrivateKey string - Timeout int - Command string - CommandExpect string + Host string + User string + Password string + PrivateKey string + KnownHosts string + HostKeyAlgorithms []string + Timeout int + Command string + CommandExpect string } func (sc *SafeConfig) ReloadConfig(configFile string) error { diff --git a/ssh_exporter.go b/ssh_exporter.go index 0ee25ed..6ae4747 100644 --- a/ssh_exporter.go +++ b/ssh_exporter.go @@ -30,6 +30,11 @@ import ( "gopkg.in/alecthomas/kingpin.v2" ) +const ( + sshEndpoint = "/ssh" + metricsEndpoint = "/metrics" +) + var ( configFile = kingpin.Flag("config.file", "Path to exporter config file").Default("ssh_exporter.yaml").String() defaultTimeout = kingpin.Flag("collector.ssh.default-timeout", "Default timeout for SSH collection").Default("10").Int() @@ -59,13 +64,15 @@ func metricsHandler(c *config.Config, logger log.Logger) http.HandlerFunc { } target := &config.Target{ - Host: t, - User: module.User, - Password: module.Password, - PrivateKey: module.PrivateKey, - Timeout: module.Timeout, - Command: module.Command, - CommandExpect: module.CommandExpect, + Host: t, + User: module.User, + Password: module.Password, + PrivateKey: module.PrivateKey, + KnownHosts: module.KnownHosts, + HostKeyAlgorithms: module.HostKeyAlgorithms, + Timeout: module.Timeout, + Command: module.Command, + CommandExpect: module.CommandExpect, } sshCollector := collector.NewCollector(target, log.With(logger, "target", target.Host)) registry.MustRegister(sshCollector) @@ -79,7 +86,6 @@ func metricsHandler(c *config.Config, logger log.Logger) http.HandlerFunc { } func main() { - metricsEndpoint := "/ssh" promlogConfig := &promlog.Config{} flag.AddFlags(kingpin.CommandLine, promlogConfig) kingpin.Version(version.Print("ssh_exporter")) @@ -104,13 +110,13 @@ func main() { SSH Exporter

SSH Exporter

-

SSH Metrics

-

Exporter Metrics

+

SSH Metrics

+

Exporter Metrics

`)) }) - http.Handle(metricsEndpoint, metricsHandler(sc.C, logger)) - http.Handle("/metrics", promhttp.Handler()) + http.Handle(sshEndpoint, metricsHandler(sc.C, logger)) + http.Handle(metricsEndpoint, promhttp.Handler()) err := http.ListenAndServe(*listenAddress, nil) if err != nil { level.Error(logger).Log("err", err)