@@ -2,6 +2,8 @@ package main
2
2
3
3
import (
4
4
"bufio"
5
+ "crypto/ecdsa"
6
+ "crypto/elliptic"
5
7
"crypto/rand"
6
8
"crypto/rsa"
7
9
"crypto/subtle"
@@ -17,11 +19,17 @@ import (
17
19
"github.com/gliderlabs/ssh"
18
20
"github.com/google/uuid"
19
21
"go.uber.org/zap"
22
+ gossh "golang.org/x/crypto/ssh"
20
23
"golang.org/x/term"
21
24
)
22
25
23
26
type ConsoleTarget int32
24
27
28
+ const (
29
+ RSAKeyType string = "RSA PRIVATE KEY"
30
+ ECKeyType = "EC PRIVATE KEY"
31
+ )
32
+
25
33
const (
26
34
stdOutTarget ConsoleTarget = 0
27
35
stdErrTarget ConsoleTarget = 1
@@ -194,55 +202,232 @@ func consoleInRoutine(stdIn io.Reader, console *Console, logger *zap.Logger) {
194
202
}
195
203
}
196
204
197
- func ensureHostKey (logger * zap.Logger ) (string , error ) {
205
+ const (
206
+ // Current filename, hides on Linux systems.
207
+ HostKeyFilename string = ".hostKey.pem"
208
+
209
+ // Old filename, not hidden.
210
+ OldHostKeyFilename = "hostKey.pem"
211
+ )
212
+
213
+ // Use the hidden form first, but fallback to the non-hidden one if it already exists.
214
+ func pickHostKeyPath (homeDir string ) string {
215
+ defaultKeyfilePath := filepath .Join (homeDir , HostKeyFilename )
216
+ _ , err := os .Stat (defaultKeyfilePath )
217
+ if ! os .IsNotExist (err ) {
218
+ return defaultKeyfilePath
219
+ }
220
+
221
+ fallbackKeyfilePath := filepath .Join (homeDir , OldHostKeyFilename )
222
+ _ , err = os .Stat (fallbackKeyfilePath )
223
+ if ! os .IsNotExist (err ) {
224
+ return fallbackKeyfilePath
225
+ }
226
+
227
+ return defaultKeyfilePath
228
+ }
229
+
230
+ // Exists to clean up the non-hidden key file if it still exists
231
+ func cleanupOldHostKey () error {
198
232
homeDir , err := os .UserHomeDir ()
199
233
if err != nil {
200
- return "" , err
234
+ return err
201
235
}
202
236
203
- keyfilePath := filepath .Join (homeDir , "hostKey.pem" )
237
+ keyfilePath := filepath .Join (homeDir , OldHostKeyFilename )
204
238
_ , err = os .Stat (keyfilePath )
205
239
if os .IsNotExist (err ) {
206
- logger .Info ("Generating host key for remote shell server." )
207
- hostKey , err := rsa .GenerateKey (rand .Reader , 4096 )
240
+ return nil
241
+ }
242
+
243
+ err = os .Remove (keyfilePath )
244
+ if err != nil {
245
+ return err
246
+ }
247
+
248
+ _ , err = os .Stat (keyfilePath )
249
+ if ! os .IsNotExist (err ) {
250
+ return err
251
+ }
252
+
253
+ return nil
254
+ }
255
+
256
+ type hostKeys struct {
257
+ rsaKey * rsa.PrivateKey
258
+ ecKey * ecdsa.PrivateKey
259
+ }
260
+
261
+ func populateKeys (keys * hostKeys , logger * zap.Logger ) (bool , error ) {
262
+ didAdd := false
263
+ if keys .ecKey == nil {
264
+ logger .Info ("Generating ECDSA SSH Host Key" )
265
+ ellipticKey , err := ecdsa .GenerateKey (elliptic .P384 (), rand .Reader )
208
266
if err != nil {
209
- return keyfilePath , err
267
+ return didAdd , err
210
268
}
211
269
212
- err = hostKey .Validate ()
270
+ keys .ecKey = ellipticKey
271
+ didAdd = true
272
+ }
273
+
274
+ if keys .rsaKey == nil {
275
+ logger .Info ("Generating RSA SSH Host Key" )
276
+ rsaKey , err := rsa .GenerateKey (rand .Reader , 4096 )
213
277
if err != nil {
214
- return keyfilePath , err
278
+ return didAdd , err
215
279
}
216
280
217
- hostDER := x509 .MarshalPKCS1PrivateKey (hostKey )
218
- hostBlock := pem.Block {
219
- Type : "RSA PRIVATE KEY" ,
220
- Headers : nil ,
221
- Bytes : hostDER ,
281
+ keys .rsaKey = rsaKey
282
+ didAdd = true
283
+ }
284
+
285
+ return didAdd , nil
286
+ }
287
+
288
+ func writeKeys (hostKeyPath string , keys * hostKeys , logger * zap.Logger ) error {
289
+ keysFile , err := os .OpenFile (hostKeyPath , os .O_CREATE + os .O_WRONLY + os .O_TRUNC , 0600 )
290
+ if err != nil {
291
+ return err
292
+ }
293
+
294
+ defer keysFile .Close ()
295
+
296
+ logger .Info (fmt .Sprintf ("Writing Host Keys to %s." , hostKeyPath ))
297
+ if keys .ecKey != nil {
298
+ ecDER , err := x509 .MarshalECPrivateKey (keys .ecKey )
299
+ if err != nil {
300
+ return err
301
+ }
302
+
303
+ ecBlock := pem.Block {
304
+ Type : ECKeyType ,
305
+ Bytes : ecDER ,
222
306
}
223
- hostPEM := pem .EncodeToMemory (& hostBlock )
224
307
225
- err = os .WriteFile (keyfilePath , hostPEM , 0600 )
226
- return keyfilePath , err
308
+ pem .Encode (keysFile , & ecBlock )
227
309
}
228
310
229
- return keyfilePath , err
311
+ if keys .rsaKey != nil {
312
+ rsaDER := x509 .MarshalPKCS1PrivateKey (keys .rsaKey )
313
+ rsaBlock := pem.Block {
314
+ Type : RSAKeyType ,
315
+ Bytes : rsaDER ,
316
+ }
317
+
318
+ pem .Encode (keysFile , & rsaBlock )
319
+ }
320
+
321
+ return nil
322
+ }
323
+
324
+ func readKeys (hostKeyPath string ) (* hostKeys , error ) {
325
+ bytes , err := os .ReadFile (hostKeyPath )
326
+ if err != nil {
327
+ return nil , err
328
+ }
329
+
330
+ var keys hostKeys
331
+ for len (bytes ) > 0 {
332
+ pemBlock , next := pem .Decode (bytes )
333
+ if pemBlock == nil {
334
+ break
335
+ }
336
+
337
+ switch pemBlock .Type {
338
+ case RSAKeyType :
339
+ rsaKey , err := x509 .ParsePKCS1PrivateKey (pemBlock .Bytes )
340
+ if err != nil {
341
+ return & keys , err
342
+ }
343
+ keys .rsaKey = rsaKey
344
+ case ECKeyType :
345
+ ecKey , err := x509 .ParseECPrivateKey (pemBlock .Bytes )
346
+ if err != nil {
347
+ return & keys , err
348
+ }
349
+ keys .ecKey = ecKey
350
+ }
351
+
352
+ bytes = next
353
+ }
354
+
355
+ return & keys , nil
356
+ }
357
+
358
+ func ensureHostKeys (logger * zap.Logger ) (* hostKeys , error ) {
359
+ homeDir , err := os .UserHomeDir ()
360
+ if err != nil {
361
+ return nil , err
362
+ }
363
+
364
+ keyfilePath := pickHostKeyPath (homeDir )
365
+ defaultKeyfilePath := filepath .Join (homeDir , HostKeyFilename )
366
+ fileChanged := keyfilePath != defaultKeyfilePath
367
+ _ , err = os .Stat (keyfilePath )
368
+ if os .IsNotExist (err ) {
369
+ logger .Info ("Generating host keys for remote shell server." )
370
+ var hostKeys hostKeys
371
+ addedKeys , err := populateKeys (& hostKeys , logger )
372
+
373
+ if (fileChanged || addedKeys ) && err == nil {
374
+ writeKeys (defaultKeyfilePath , & hostKeys , logger )
375
+ }
376
+ return & hostKeys , err
377
+ } else {
378
+ logger .Info (fmt .Sprintf ("Reading host keys for remote shell from %s." , keyfilePath ))
379
+ hostKeys , err := readKeys (keyfilePath )
380
+ if err != nil {
381
+ return nil , err
382
+ }
383
+
384
+ // Populate missing keys (older files only have RSA)
385
+ addedKeys , err := populateKeys (hostKeys , logger )
386
+
387
+ if (fileChanged || addedKeys ) && err == nil {
388
+ writeKeys (defaultKeyfilePath , hostKeys , logger )
389
+ }
390
+ return hostKeys , err
391
+ }
392
+ }
393
+
394
+ func twinKeys (keys * hostKeys ) ssh.Option {
395
+ return func (srv * ssh.Server ) error {
396
+ rsaSigner , err := gossh .NewSignerFromKey (keys .rsaKey )
397
+ if err != nil {
398
+ return err
399
+ }
400
+ srv .AddHostKey (rsaSigner )
401
+
402
+ ecSigner , err := gossh .NewSignerFromKey (keys .ecKey )
403
+ if err != nil {
404
+ return err
405
+ }
406
+ srv .AddHostKey (ecSigner )
407
+
408
+ return nil
409
+ }
230
410
}
231
411
232
412
func runRemoteShellServer (console * Console , logger * zap.Logger ) {
233
413
logger .Info ("Starting remote shell server on 2222..." )
234
414
ssh .Handle (func (s ssh.Session ) { handleSession (s , console , logger ) })
235
415
236
- hostKeyPath , err := ensureHostKey (logger )
416
+ hostKeys , err := ensureHostKeys (logger )
237
417
if err != nil {
238
- logger .Error ("Unable to ensure host key exists " , zap .Error (err ))
418
+ logger .Error ("Unable to ensure host keys exist " , zap .Error (err ))
239
419
return
240
420
}
241
421
422
+ err = cleanupOldHostKey ()
423
+ if err != nil {
424
+ logger .Warn ("Unable to remote old host key file" , zap .Error (err ))
425
+ }
426
+
242
427
log .Fatal (ssh .ListenAndServe (
243
428
":2222" ,
244
429
nil ,
245
- ssh . HostKeyFile ( hostKeyPath ),
430
+ twinKeys ( hostKeys ),
246
431
ssh .PasswordAuth (func (ctx ssh.Context , password string ) bool { return passwordHandler (ctx , password , logger ) }),
247
432
))
248
433
}
0 commit comments