Skip to content

Commit ab60683

Browse files
authored
Support both ECDSA + RSA for SSH Server (#65)
1 parent 03c048d commit ab60683

File tree

1 file changed

+205
-20
lines changed

1 file changed

+205
-20
lines changed

remote_shell_service.go

Lines changed: 205 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package main
22

33
import (
44
"bufio"
5+
"crypto/ecdsa"
6+
"crypto/elliptic"
57
"crypto/rand"
68
"crypto/rsa"
79
"crypto/subtle"
@@ -17,11 +19,17 @@ import (
1719
"github.com/gliderlabs/ssh"
1820
"github.com/google/uuid"
1921
"go.uber.org/zap"
22+
gossh "golang.org/x/crypto/ssh"
2023
"golang.org/x/term"
2124
)
2225

2326
type ConsoleTarget int32
2427

28+
const (
29+
RSAKeyType string = "RSA PRIVATE KEY"
30+
ECKeyType = "EC PRIVATE KEY"
31+
)
32+
2533
const (
2634
stdOutTarget ConsoleTarget = 0
2735
stdErrTarget ConsoleTarget = 1
@@ -194,55 +202,232 @@ func consoleInRoutine(stdIn io.Reader, console *Console, logger *zap.Logger) {
194202
}
195203
}
196204

197-
func ensureHostKey(logger *zap.Logger) (string, error) {
205+
const (
206+
// Current filename, hides on Linux systems.
207+
HostKeyFilename string = ".hostKey.pem"
208+
209+
// Old filename, not hidden.
210+
OldHostKeyFilename = "hostKey.pem"
211+
)
212+
213+
// Use the hidden form first, but fallback to the non-hidden one if it already exists.
214+
func pickHostKeyPath(homeDir string) string {
215+
defaultKeyfilePath := filepath.Join(homeDir, HostKeyFilename)
216+
_, err := os.Stat(defaultKeyfilePath)
217+
if !os.IsNotExist(err) {
218+
return defaultKeyfilePath
219+
}
220+
221+
fallbackKeyfilePath := filepath.Join(homeDir, OldHostKeyFilename)
222+
_, err = os.Stat(fallbackKeyfilePath)
223+
if !os.IsNotExist(err) {
224+
return fallbackKeyfilePath
225+
}
226+
227+
return defaultKeyfilePath
228+
}
229+
230+
// Exists to clean up the non-hidden key file if it still exists
231+
func cleanupOldHostKey() error {
198232
homeDir, err := os.UserHomeDir()
199233
if err != nil {
200-
return "", err
234+
return err
201235
}
202236

203-
keyfilePath := filepath.Join(homeDir, "hostKey.pem")
237+
keyfilePath := filepath.Join(homeDir, OldHostKeyFilename)
204238
_, err = os.Stat(keyfilePath)
205239
if os.IsNotExist(err) {
206-
logger.Info("Generating host key for remote shell server.")
207-
hostKey, err := rsa.GenerateKey(rand.Reader, 4096)
240+
return nil
241+
}
242+
243+
err = os.Remove(keyfilePath)
244+
if err != nil {
245+
return err
246+
}
247+
248+
_, err = os.Stat(keyfilePath)
249+
if !os.IsNotExist(err) {
250+
return err
251+
}
252+
253+
return nil
254+
}
255+
256+
type hostKeys struct {
257+
rsaKey *rsa.PrivateKey
258+
ecKey *ecdsa.PrivateKey
259+
}
260+
261+
func populateKeys(keys *hostKeys, logger *zap.Logger) (bool, error) {
262+
didAdd := false
263+
if keys.ecKey == nil {
264+
logger.Info("Generating ECDSA SSH Host Key")
265+
ellipticKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
208266
if err != nil {
209-
return keyfilePath, err
267+
return didAdd, err
210268
}
211269

212-
err = hostKey.Validate()
270+
keys.ecKey = ellipticKey
271+
didAdd = true
272+
}
273+
274+
if keys.rsaKey == nil {
275+
logger.Info("Generating RSA SSH Host Key")
276+
rsaKey, err := rsa.GenerateKey(rand.Reader, 4096)
213277
if err != nil {
214-
return keyfilePath, err
278+
return didAdd, err
215279
}
216280

217-
hostDER := x509.MarshalPKCS1PrivateKey(hostKey)
218-
hostBlock := pem.Block{
219-
Type: "RSA PRIVATE KEY",
220-
Headers: nil,
221-
Bytes: hostDER,
281+
keys.rsaKey = rsaKey
282+
didAdd = true
283+
}
284+
285+
return didAdd, nil
286+
}
287+
288+
func writeKeys(hostKeyPath string, keys *hostKeys, logger *zap.Logger) error {
289+
keysFile, err := os.OpenFile(hostKeyPath, os.O_CREATE+os.O_WRONLY+os.O_TRUNC, 0600)
290+
if err != nil {
291+
return err
292+
}
293+
294+
defer keysFile.Close()
295+
296+
logger.Info(fmt.Sprintf("Writing Host Keys to %s.", hostKeyPath))
297+
if keys.ecKey != nil {
298+
ecDER, err := x509.MarshalECPrivateKey(keys.ecKey)
299+
if err != nil {
300+
return err
301+
}
302+
303+
ecBlock := pem.Block{
304+
Type: ECKeyType,
305+
Bytes: ecDER,
222306
}
223-
hostPEM := pem.EncodeToMemory(&hostBlock)
224307

225-
err = os.WriteFile(keyfilePath, hostPEM, 0600)
226-
return keyfilePath, err
308+
pem.Encode(keysFile, &ecBlock)
227309
}
228310

229-
return keyfilePath, err
311+
if keys.rsaKey != nil {
312+
rsaDER := x509.MarshalPKCS1PrivateKey(keys.rsaKey)
313+
rsaBlock := pem.Block{
314+
Type: RSAKeyType,
315+
Bytes: rsaDER,
316+
}
317+
318+
pem.Encode(keysFile, &rsaBlock)
319+
}
320+
321+
return nil
322+
}
323+
324+
func readKeys(hostKeyPath string) (*hostKeys, error) {
325+
bytes, err := os.ReadFile(hostKeyPath)
326+
if err != nil {
327+
return nil, err
328+
}
329+
330+
var keys hostKeys
331+
for len(bytes) > 0 {
332+
pemBlock, next := pem.Decode(bytes)
333+
if pemBlock == nil {
334+
break
335+
}
336+
337+
switch pemBlock.Type {
338+
case RSAKeyType:
339+
rsaKey, err := x509.ParsePKCS1PrivateKey(pemBlock.Bytes)
340+
if err != nil {
341+
return &keys, err
342+
}
343+
keys.rsaKey = rsaKey
344+
case ECKeyType:
345+
ecKey, err := x509.ParseECPrivateKey(pemBlock.Bytes)
346+
if err != nil {
347+
return &keys, err
348+
}
349+
keys.ecKey = ecKey
350+
}
351+
352+
bytes = next
353+
}
354+
355+
return &keys, nil
356+
}
357+
358+
func ensureHostKeys(logger *zap.Logger) (*hostKeys, error) {
359+
homeDir, err := os.UserHomeDir()
360+
if err != nil {
361+
return nil, err
362+
}
363+
364+
keyfilePath := pickHostKeyPath(homeDir)
365+
defaultKeyfilePath := filepath.Join(homeDir, HostKeyFilename)
366+
fileChanged := keyfilePath != defaultKeyfilePath
367+
_, err = os.Stat(keyfilePath)
368+
if os.IsNotExist(err) {
369+
logger.Info("Generating host keys for remote shell server.")
370+
var hostKeys hostKeys
371+
addedKeys, err := populateKeys(&hostKeys, logger)
372+
373+
if (fileChanged || addedKeys) && err == nil {
374+
writeKeys(defaultKeyfilePath, &hostKeys, logger)
375+
}
376+
return &hostKeys, err
377+
} else {
378+
logger.Info(fmt.Sprintf("Reading host keys for remote shell from %s.", keyfilePath))
379+
hostKeys, err := readKeys(keyfilePath)
380+
if err != nil {
381+
return nil, err
382+
}
383+
384+
// Populate missing keys (older files only have RSA)
385+
addedKeys, err := populateKeys(hostKeys, logger)
386+
387+
if (fileChanged || addedKeys) && err == nil {
388+
writeKeys(defaultKeyfilePath, hostKeys, logger)
389+
}
390+
return hostKeys, err
391+
}
392+
}
393+
394+
func twinKeys(keys *hostKeys) ssh.Option {
395+
return func(srv *ssh.Server) error {
396+
rsaSigner, err := gossh.NewSignerFromKey(keys.rsaKey)
397+
if err != nil {
398+
return err
399+
}
400+
srv.AddHostKey(rsaSigner)
401+
402+
ecSigner, err := gossh.NewSignerFromKey(keys.ecKey)
403+
if err != nil {
404+
return err
405+
}
406+
srv.AddHostKey(ecSigner)
407+
408+
return nil
409+
}
230410
}
231411

232412
func runRemoteShellServer(console *Console, logger *zap.Logger) {
233413
logger.Info("Starting remote shell server on 2222...")
234414
ssh.Handle(func(s ssh.Session) { handleSession(s, console, logger) })
235415

236-
hostKeyPath, err := ensureHostKey(logger)
416+
hostKeys, err := ensureHostKeys(logger)
237417
if err != nil {
238-
logger.Error("Unable to ensure host key exists", zap.Error(err))
418+
logger.Error("Unable to ensure host keys exist", zap.Error(err))
239419
return
240420
}
241421

422+
err = cleanupOldHostKey()
423+
if err != nil {
424+
logger.Warn("Unable to remote old host key file", zap.Error(err))
425+
}
426+
242427
log.Fatal(ssh.ListenAndServe(
243428
":2222",
244429
nil,
245-
ssh.HostKeyFile(hostKeyPath),
430+
twinKeys(hostKeys),
246431
ssh.PasswordAuth(func(ctx ssh.Context, password string) bool { return passwordHandler(ctx, password, logger) }),
247432
))
248433
}

0 commit comments

Comments
 (0)