-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathknownhost.go
141 lines (121 loc) · 3.08 KB
/
knownhost.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package knownhost
import (
	"bufio"
	"context"
	"net"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
)
// KnownHost looks up host keys in an OpenSSH known_hosts file and can probe
// remote SSH servers for the host keys they present.
type KnownHost struct {
	// knownHostFile is the path of the known_hosts file read by
	// ReadLocalHostKeyForHost. It is set through the Option values passed to
	// NewKnownHost and is empty on the zero value.
	knownHostFile string
}
// NewKnownHost allocates a KnownHost and applies each of opts to it in the
// order given. With no options the returned value has an empty known_hosts
// path.
func NewKnownHost(opts ...Option) *KnownHost {
	kh := new(KnownHost)
	for i := 0; i < len(opts); i++ {
		opts[i](kh)
	}
	return kh
}
// Option mutates a KnownHost during construction; NewKnownHost applies the
// options it receives in order.
type Option func(*KnownHost)
// WithDefaultKnownHostsFile returns an Option that, when yes is true, points
// the KnownHost at the current user's default known_hosts file as resolved by
// GetDefaultKnownHostFile. When yes is false the Option does nothing.
//
// An Option has no way to report an error. If the default path cannot be
// resolved, the Option therefore leaves the current setting untouched. It does
// not overwrite it with an empty path, which would silently discard, for
// example, a path configured by an earlier WithCustomFile.
func WithDefaultKnownHostsFile(yes bool) Option {
	return func(host *KnownHost) {
		if !yes {
			return
		}
		path, err := host.GetDefaultKnownHostFile()
		if err != nil {
			// Deliberately best-effort: keep whatever path was already set.
			return
		}
		host.knownHostFile = path
	}
}
// WithCustomFile returns an Option that makes the KnownHost read its keys
// from fileName instead of any previously configured path.
func WithCustomFile(fileName string) Option {
	setFile := func(kh *KnownHost) {
		kh.knownHostFile = fileName
	}
	return setFile
}
// GetDefaultKnownHostFile returns the conventional per-user known_hosts path,
// <home>/.ssh/known_hosts, where <home> comes from os.UserHomeDir. If the home
// directory cannot be determined, it returns "" and the error from
// os.UserHomeDir.
func (k *KnownHost) GetDefaultKnownHostFile() (string, error) {
	home, err := os.UserHomeDir()
	if err == nil {
		return filepath.Join(home, ".ssh", "known_hosts"), nil
	}
	return "", err
}
// ReadLocalHostKeyForHost looks up host in the configured known_hosts file and
// returns the public key recorded for it.
//
// Matching rules:
//   - host is compared exactly, never as a substring, against each
//     comma-separated entry of a line's host field.
//   - It matches either verbatim or in OpenSSH's normalized address form:
//     "h" for "h:22", and "[h]:port" for any other port.
//   - Blank lines, comment lines ("#") and marker lines ("@cert-authority",
//     "@revoked") are skipped.
//   - Hashed ("|1|...") entries never match.
//   - When several lines match, the key from the last one is returned.
//
// A nil key with a nil error means no entry matched. An error is returned if
// the file cannot be opened or read, or if a matching line's key cannot be
// parsed.
func (k *KnownHost) ReadLocalHostKeyForHost(host string) (hostKey ssh.PublicKey, err error) {
	file, err := os.Open(k.knownHostFile)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		// Fields splits on any whitespace (spaces, tabs, trailing \r) and
		// tolerates the optional trailing comment after the key.
		fields := strings.Fields(scanner.Text())
		if len(fields) < 3 || strings.HasPrefix(fields[0], "@") || strings.HasPrefix(fields[0], "#") {
			continue
		}
		if !knownHostsLineMatches(fields[0], host) {
			continue
		}
		// Parse only "keytype base64 [comment]" so the host field is not
		// misinterpreted as authorized_keys options.
		key, _, _, _, parseErr := ssh.ParseAuthorizedKey([]byte(strings.Join(fields[1:], " ")))
		if parseErr != nil {
			return nil, parseErr
		}
		hostKey = key
	}
	if err = scanner.Err(); err != nil {
		return nil, err
	}
	return hostKey, nil
}

// knownHostsLineMatches reports whether host is one of the comma-separated
// entries in hostField, the first field of a known_hosts line.
func knownHostsLineMatches(hostField, host string) bool {
	normalized := normalizeKnownHostsAddress(host)
	for _, entry := range strings.Split(hostField, ",") {
		if entry == host || entry == normalized {
			return true
		}
	}
	return false
}

// normalizeKnownHostsAddress rewrites a "host:port" address into the form
// OpenSSH stores in known_hosts. That is the bare host for port 22 and
// "[host]:port" for any other port. Strings without a port are returned
// unchanged.
func normalizeKnownHostsAddress(addr string) string {
	h, port, err := net.SplitHostPort(addr)
	if err != nil {
		return addr
	}
	if port == "22" {
		return h
	}
	return "[" + h + "]:" + port
}
// GetKeysForHost probes host once per algorithm in supportedHostKeyAlgorithms
// and collects the host keys the server presents.
//
// host reaches the probes through the context; getPublicKey dials it over TCP,
// so it presumably needs to be a "host:port" address.
//
// Probes run concurrently. The call returns as soon as every probe has
// finished or timeout has elapsed, whichever comes first. Keys from probes
// that fail, or that are still running at the deadline, are omitted. The
// returned error is always nil.
func (k *KnownHost) GetKeysForHost(host string, timeout time.Duration) ([]ssh.PublicKey, error) {
	ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	base := context.WithValue(ctxTimeout, timeoutKey, timeout)
	base = context.WithValue(base, hostKey, host)

	// Each probe sends at most one key. With this buffer size no sender can
	// block, even after this function has stopped receiving, so probes never
	// leak on the send.
	recv := make(chan ssh.PublicKey, len(supportedHostKeyAlgorithms))
	var wg sync.WaitGroup
	for _, algorithm := range supportedHostKeyAlgorithms {
		wg.Add(1)
		// Derive each probe's context from base so the value chain does not
		// grow with every algorithm.
		go func(ctx context.Context) {
			defer wg.Done()
			processFilterData(ctx, recv)
		}(context.WithValue(base, algorithmKey, algorithm))
	}
	// Closing recv after all probes finish lets the loop below return early
	// instead of always waiting out the full timeout.
	go func() {
		wg.Wait()
		close(recv)
	}()

	publicKeys := make([]ssh.PublicKey, 0, len(supportedHostKeyAlgorithms))
	for {
		select {
		case <-ctxTimeout.Done():
			return publicKeys, nil
		case pubKey, ok := <-recv:
			if !ok {
				return publicKeys, nil
			}
			publicKeys = append(publicKeys, pubKey)
		}
	}
}
// processFilterData runs one host-key probe (getPublicKey) using the settings
// carried in ctx and forwards any key it obtains on recv. A probe that yields
// no key sends nothing.
//
// The send is abandoned once ctx is done. Without that, a goroutine whose
// receiver has already returned would block on recv forever. When ctx is
// already done, the key may therefore be dropped rather than delivered.
func processFilterData(ctx context.Context, recv chan ssh.PublicKey) {
	publicKey := getPublicKey(ctx)
	if publicKey == nil {
		return
	}
	select {
	case recv <- publicKey:
	case <-ctx.Done():
	}
}
// getPublicKey dials the host stored in ctx over TCP and starts an SSH
// handshake that offers only the host-key algorithm stored in ctx. It returns
// the host key the server presented, or nil if none was captured (for
// example: dial failure, algorithm not supported by the server, deadline
// exceeded).
//
// Time bounds:
//   - The dial honours ctx cancellation as well as the timeout stored in ctx.
//   - The rest of the exchange is bounded by ctx's deadline, or, if ctx has
//     none, by that timeout. Without this, a server that accepts the TCP
//     connection but never finishes the handshake would hold this goroutine
//     and its socket indefinitely.
//
// No authentication methods are configured, so NewClientConn will presumably
// fail after key exchange. The key captured by the host-key callback is
// returned regardless of that error.
func getPublicKey(ctx context.Context) (key ssh.PublicKey) {
	timeout := getTimeoutFromContext(ctx)
	host := getHostFromContext(ctx)
	algorithm := getAlgorithmFromContext(ctx)

	d := net.Dialer{Timeout: timeout}
	conn, err := d.DialContext(ctx, "tcp", host)
	if err != nil {
		return nil
	}
	defer conn.Close()

	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline && timeout > 0 {
		deadline, hasDeadline = time.Now().Add(timeout), true
	}
	if hasDeadline {
		if dlErr := conn.SetDeadline(deadline); dlErr != nil {
			// Proceeding without a deadline could hang; give up on this probe.
			return nil
		}
	}

	config := ssh.ClientConfig{
		HostKeyAlgorithms: []string{algorithm},
		HostKeyCallback:   hostKeyCallback(&key),
	}
	sshConn, _, _, err := ssh.NewClientConn(conn, host, &config)
	if err == nil {
		sshConn.Close()
	}
	return key
}
// hostKeyCallback builds a host-key callback that records the key presented
// by the server into *publicKey and always accepts it by returning nil. It
// performs no verification at all; its sole purpose is to capture the key.
func hostKeyCallback(publicKey *ssh.PublicKey) func(hostname string, remote net.Addr, key ssh.PublicKey) error {
	capture := func(_ string, _ net.Addr, presented ssh.PublicKey) error {
		*publicKey = presented
		return nil
	}
	return capture
}