diff --git a/README.md b/README.md index 317feb5..57d517b 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,15 @@ func main() { } } + // init an empty public key with version 3 + outKey := putty.Key{Version: 3} + + // set the private key + outKey.SetKey(privateKey) + + // print out the ppk file + fmt.Printf("%s\n", outKey.Marshal()) + log.Printf("%+#v", privateKey) } ``` diff --git a/dsa.go b/dsa.go index 48b10f1..fdad894 100644 --- a/dsa.go +++ b/dsa.go @@ -14,7 +14,7 @@ func (k Key) readDSAPublicKey() (*dsa.PublicKey, error) { G *big.Int Pub *big.Int } - err := unmarshal(k.PublicKey, &pub, false) + _, err := unmarshal(k.PublicKey, &pub, false) if err != nil { return nil, err } @@ -35,14 +35,32 @@ func (k Key) readDSAPublicKey() (*dsa.PublicKey, error) { return publicKey, nil } -func (k Key) readDSAPrivateKey() (*dsa.PrivateKey, error) { +func (k *Key) setDSAPublicKey(toSet *dsa.PublicKey) (err error) { + var pub struct { + Header string + P *big.Int + Q *big.Int + G *big.Int + Pub *big.Int + } + k.Algo = "ssh-dss" + pub.Header = k.Algo + pub.P = toSet.Parameters.P + pub.Q = toSet.Parameters.Q + pub.G = toSet.Parameters.G + pub.Pub = toSet.Y + k.PublicKey, err = marshal(&pub) + return +} + +func (k *Key) readDSAPrivateKey() (*dsa.PrivateKey, error) { publicKey, err := k.readDSAPublicKey() if err != nil { return nil, err } var priv *big.Int - err = unmarshal(k.PrivateKey, &priv, k.Encryption != "none") + k.keySize, err = unmarshal(k.PrivateKey, &priv, k.padded) if err != nil { return nil, err } @@ -54,3 +72,17 @@ func (k Key) readDSAPrivateKey() (*dsa.PrivateKey, error) { return privateKey, nil } + +func (k *Key) setDSAPrivateKey(pk *dsa.PrivateKey) (err error) { + err = k.setDSAPublicKey(&pk.PublicKey) + if err != nil { + return err + } + + var priv *big.Int + priv = pk.X + k.PrivateKey, err = marshal(&priv) + k.keySize = len(k.PrivateKey) + k.padded = false + return +} diff --git a/ecdsa.go b/ecdsa.go index 57317c8..6535323 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -14,7 +14,7 @@ func (k Key) readECDSAPublicKey() (*ecdsa.PublicKey, error) { Length string Bytes []byte } - err := unmarshal(k.PublicKey, &pub, false) + _, err := unmarshal(k.PublicKey, &pub, false) if err != nil { return nil, err } @@ -52,14 +52,52 @@ func (k Key) readECDSAPublicKey() (*ecdsa.PublicKey, error) { return publicKey, nil } -func (k Key) readECDSAPrivateKey() (*ecdsa.PrivateKey, error) { +func (k *Key) setECDSAPublicKey(pk *ecdsa.PublicKey) (err error) { + var pub struct { + Header string + Length string + Bytes []byte + } + + switch c := pk.Curve.Params().Name; c { + case "P-256": + pub.Length = "nistp256" + case "P-384": + pub.Length = "nistp384" + case "P-521": + pub.Length = "nistp521" + default: + return fmt.Errorf("unsupported elliptic curve %s", c) + } + + pub.Header = "ecdsa-sha2-" + pub.Length + k.Algo = pub.Header + + x := (pk.X).Bytes() + y := (pk.Y).Bytes() + + // balance the integers slices + if diff := len(x) - len(y); diff > 0 { + y = append(make([]byte, diff), y...) + } else if diff < 0 { + x = append(make([]byte, -diff), x...) + } + + pub.Bytes = append([]byte{4}, x...) + pub.Bytes = append(pub.Bytes, y...) + + k.PublicKey, err = marshal(&pub) + return +} + +func (k *Key) readECDSAPrivateKey() (*ecdsa.PrivateKey, error) { publicKey, err := k.readECDSAPublicKey() if err != nil { return nil, err } var priv *big.Int - err = unmarshal(k.PrivateKey, &priv, k.Encryption != "none") + k.keySize, err = unmarshal(k.PrivateKey, &priv, k.padded) if err != nil { return nil, err } @@ -87,3 +125,14 @@ func (k Key) readECDSAPrivateKey() (*ecdsa.PrivateKey, error) { return privateKey, nil } + +func (k *Key) setECDSAPrivateKey(pk *ecdsa.PrivateKey) (err error) { + err = k.setECDSAPublicKey(&pk.PublicKey) + if err != nil { + return + } + k.PrivateKey, err = marshal(&pk.D) + k.keySize = len(k.PrivateKey) + k.padded = false + return +} diff --git a/ed25519.go b/ed25519.go index 62e25d8..7201c1d 100644 --- a/ed25519.go +++ b/ed25519.go @@ -10,7 +10,7 @@ func (k Key) readED25519PublicKey() (*ed25519.PublicKey, error) { Header string Bytes []byte } - err := unmarshal(k.PublicKey, &pub, false) + _, err := unmarshal(k.PublicKey, &pub, false) if err != nil { return nil, err } @@ -26,14 +26,26 @@ func (k Key) readED25519PublicKey() (*ed25519.PublicKey, error) { return (*ed25519.PublicKey)(&pub.Bytes), nil } -func (k Key) readED25519PrivateKey() (*ed25519.PrivateKey, error) { +func (k *Key) setED25519PublicKey(pk *ed25519.PublicKey) (err error) { + var pub struct { + Header string + Bytes []byte + } + k.Algo = "ssh-ed25519" + pub.Header = k.Algo + pub.Bytes = ([]byte)(*pk) + k.PublicKey, err = marshal(&pub) + return +} + +func (k *Key) readED25519PrivateKey() (*ed25519.PrivateKey, error) { publicKey, err := k.readED25519PublicKey() if err != nil { return nil, err } var priv []byte - err = unmarshal(k.PrivateKey, &priv, k.Encryption != "none") + k.keySize, err = unmarshal(k.PrivateKey, &priv, k.padded) if err != nil { return nil, err } @@ -48,3 +60,19 @@ func (k Key) readED25519PrivateKey() (*ed25519.PrivateKey, error) { return &privateKey, nil } + +func (k *Key) setED25519PrivateKey(pk *ed25519.PrivateKey) (err error) { + bytes := ([]byte)(*pk) + cut := ed25519.PrivateKeySize - ed25519.PublicKeySize + pub := bytes[cut:] + err = k.setED25519PublicKey((*ed25519.PublicKey)(&pub)) + if err != nil { + return err + } + + priv := bytes[:cut] + k.PrivateKey, err = marshal(&priv) + k.keySize = len(k.PrivateKey) + k.padded = false + return +} diff --git a/go.mod b/go.mod index 9b91c04..9114b12 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/kayrus/putty go 1.13 -require golang.org/x/crypto v0.1.0 +require ( + github.com/pschou/go-cbc3 v0.0.0-20221001011156-4e1428fbe50c // indirect + golang.org/x/crypto golang.org/x/crypto v0.1.0 +) \ No newline at end of file diff --git a/go.sum b/go.sum index 637640c..1619c8a 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,6 @@ +github.com/pschou/go-cbc3 v0.0.0-20220927161413-29a957308e2d/go.mod h1:b5sRQXy0ZdH7/nBD+FLWT4W+R4OIgNTw4rM9o/+XeKc= +github.com/pschou/go-cbc3 v0.0.0-20221001011156-4e1428fbe50c h1:dk8hPEU14qkDRAfX/Lzt0N63I/2iW+vYlZecK7Vkc+4= +github.com/pschou/go-cbc3 v0.0.0-20221001011156-4e1428fbe50c/go.mod h1:e5fPLORCNq9grPyodp74EPO3uxyrZUEjXWCErhGMlLI= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/lib.go b/lib.go new file mode 100644 index 0000000..9e927fb --- /dev/null +++ b/lib.go @@ -0,0 +1,15 @@ +package putty + +func splitByWidth(str string, size int) []string { + strLength := len(str) + var splited []string + var stop int + for i := 0; i < strLength; i += size { + stop = i + size + if stop > strLength { + stop = strLength + } + splited = append(splited, str[i:stop]) + } + return splited +} diff --git a/marshal.go b/marshal.go new file mode 100644 index 0000000..9971516 --- /dev/null +++ b/marshal.go @@ -0,0 +1,126 @@ +package putty + +import ( + "bytes" + "crypto/aes" + "crypto/sha1" + "encoding/binary" + "fmt" + "math/big" + "reflect" +) + +func marshal(val interface{}) (data []byte, err error) { + v := reflect.ValueOf(val).Elem() + buf := bytes.NewBuffer([]byte{}) + + err = writeField(v, buf) + data = buf.Bytes() + return +} + +func addPadding(data []byte) []byte { + keySize := len(data) + if keySize%aes.BlockSize == 0 { + return data + } + sha := sha1.Sum(data) + padSize := aes.BlockSize - keySize&(aes.BlockSize-1) + data = append(data, make([]byte, padSize)...) + copy(data[keySize:], sha[:]) + return data +} + +func writeField(v reflect.Value, dst *bytes.Buffer) error { + fieldType := v.Type() + + switch fieldType.Kind() { + case reflect.Struct: + for i := 0; i < fieldType.NumField(); i++ { + if fieldType.Field(i).PkgPath != "" { + return fmt.Errorf("struct contains unexported fields") + } + + err := writeField(v.Field(i), dst) + if err != nil { + return err + } + } + return nil + } + + switch fieldType { + case reflect.TypeOf(string("")): + err := writeString2(dst, v.String()) + if err != nil { + return err + } + case reflect.TypeOf([]byte(nil)): + err := writeBytes2(dst, v.Bytes()) + if err != nil { + return err + } + case reflect.TypeOf(new(big.Int)): + switch val := (v.Interface()).(type) { + case *big.Int: + err := writeBigInt2(dst, val) + if err != nil { + return err + } + default: + return fmt.Errorf("unable to set big int") + } + default: + return fmt.Errorf("unknown type %s", fieldType) + } + + return nil +} + +func writeBytes1(dst *bytes.Buffer, data []byte, length uint16) error { + //length := uint16(len(data) * 8) + // write 4 bytes (data size) of the next element size + err := binary.Write(dst, binary.BigEndian, &length) + if err != nil { + return err + } + + _, err = dst.Write(data) + + return err +} + +func writeBigInt1(dst *bytes.Buffer, data *big.Int) error { + b := data.Bytes() + + return writeBytes1(dst, b, uint16(data.BitLen())) +} + +func writeInt32(dst *bytes.Buffer, val uint32) error { + return binary.Write(dst, binary.BigEndian, &val) +} + +func writeBytes2(dst *bytes.Buffer, data []byte) error { + length := uint32(len(data)) + // write 4 bytes (data size) of the next element size + err := binary.Write(dst, binary.BigEndian, &length) + if err != nil { + return err + } + + _, err = dst.Write(data) + + return err +} + +func writeString2(dst *bytes.Buffer, data string) error { + return writeBytes2(dst, []byte(data)) +} + +func writeBigInt2(dst *bytes.Buffer, data *big.Int) error { + b := data.Bytes() + for (len(b) > 8 && b[0] >= 128) || len(b) == 2 || len(b) == 3 { + b = append([]byte{0}, b...) + } + return writeBytes2(dst, b) +} diff --git a/putty.go b/putty.go index 393e309..06eb08c 100644 --- a/putty.go +++ b/putty.go @@ -5,7 +5,11 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "crypto/dsa" + "crypto/ecdsa" + "crypto/ed25519" "crypto/hmac" + "crypto/rsa" "crypto/sha1" "crypto/sha256" "encoding/base64" @@ -16,6 +20,7 @@ import ( "io" "os" "strconv" + "strings" "golang.org/x/crypto/argon2" ) @@ -53,7 +58,102 @@ type Key struct { Comment string Encryption string PrivateMac []byte - decrypted bool + padded bool + keySize int +} + +func (k Key) Clone() *Key { + return &Key{ + Version: k.Version, + Algo: k.Algo, + PublicKey: append([]byte{}, k.PublicKey...), + PrivateKey: append([]byte{}, k.PrivateKey...), + KeyDerivation: k.KeyDerivation, + Argon2Memory: k.Argon2Memory, + Argon2Passes: k.Argon2Passes, + Argon2Parallelism: k.Argon2Parallelism, + Argon2Salt: append([]byte{}, k.Argon2Salt...), + Comment: k.Comment, + Encryption: k.Encryption, + PrivateMac: append([]byte{}, k.PrivateMac...), + padded: k.padded, + keySize: k.keySize, + } +} + +var noNewLines = strings.NewReplacer("\r", "", "\n", "") + +// Marshal returns the key in the raw ppk format for saving to a file. +func (k *Key) Marshal() (ret []byte, err error) { + // Helpful notes about the putty formats: + // https://tartarus.org/~simon/putty-snapshots/htmldoc/AppendixC.html + buf := new(bytes.Buffer) + switch k.Version { + case 1: + buf.WriteString(puttyHeaderV1) + case 2: + buf.WriteString(puttyHeaderV2) + case 3, 0: + k.Version = 3 + buf.WriteString(puttyHeaderV3) + default: + return ret, fmt.Errorf("PuTTY key format verion needs to be set to 1, 2, or 3") + } + + switch k.Algo { + case "ssh-rsa", + "ecdsa-sha2-nistp256", + "ecdsa-sha2-nistp384", + "ecdsa-sha2-nistp521", + "ssh-dss", + "ssh-ed25519": + fmt.Fprintf(buf, ": %s\r\n", k.Algo) + default: + return ret, fmt.Errorf("invalid algorithm") + } + + fmt.Fprintf(buf, "Encryption: %s\r\n", k.Encryption) + + if k.Comment == "" { + k.Comment = "PuTTY key" + } + fmt.Fprintf(buf, "Comment: %s\r\n", noNewLines.Replace(k.Comment)) + + pub := splitByWidth(base64.StdEncoding.EncodeToString(k.PublicKey), 64) + fmt.Fprintf(buf, "Public-Lines: %d\r\n", len(pub)) + fmt.Fprintf(buf, "%s\r\n", strings.Join(pub, "\r\n")) + + if len(k.PrivateKey) > 0 { + if k.KeyDerivation != "" { + fmt.Fprintf(buf, "Key-Derivation: %s\r\n", k.KeyDerivation) + } + if k.Argon2Memory > 0 { + fmt.Fprintf(buf, "Argon2-Memory: %d\r\n", k.Argon2Memory) + } + if k.Argon2Passes > 0 { + fmt.Fprintf(buf, "Argon2-Passes: %d\r\n", k.Argon2Passes) + } + if k.Argon2Parallelism > 0 { + fmt.Fprintf(buf, "Argon2-Parallelism: %d\r\n", k.Argon2Parallelism) + } + if len(k.Argon2Salt) > 0 { + fmt.Fprintf(buf, "Argon2-Salt: %02x\r\n", k.Argon2Salt) + } + + priv := splitByWidth(base64.StdEncoding.EncodeToString(k.PrivateKey), 64) + fmt.Fprintf(buf, "Private-Lines: %d\r\n", len(priv)) + fmt.Fprintf(buf, "%s\r\n", strings.Join(priv, "\r\n")) + + if k.Encryption == "none" { + k.calculateHMAC(nil) + } + if k.Version == 1 && k.Encryption == "none" { + fmt.Fprintf(buf, "Private-Hash: %0x", k.PrivateMac) + } else { + fmt.Fprintf(buf, "Private-MAC: %0x", k.PrivateMac) + } + } + return buf.Bytes(), nil } type reader interface { @@ -116,32 +216,87 @@ func New(b []byte) (*Key, error) { return k, nil } +// SetPrivateKey sets the private key. It supports RSA (PKCS#1), DSA (OpenSSL), ECDSA and ED25519 private keys. +func (k *Key) SetKey(key interface{}) (err error) { + switch PrivateKey := key.(type) { + case *rsa.PrivateKey: + err = k.setRSAPrivateKey(PrivateKey) + case rsa.PrivateKey: + err = k.setRSAPrivateKey(&PrivateKey) + + case *dsa.PrivateKey: + err = k.setDSAPrivateKey(PrivateKey) + case dsa.PrivateKey: + err = k.setDSAPrivateKey(&PrivateKey) + + case *ecdsa.PrivateKey: + err = k.setECDSAPrivateKey(PrivateKey) + case ecdsa.PrivateKey: + err = k.setECDSAPrivateKey(&PrivateKey) + + case *ed25519.PrivateKey: + err = k.setED25519PrivateKey(PrivateKey) + case ed25519.PrivateKey: + err = k.setED25519PrivateKey(&PrivateKey) + default: + return fmt.Errorf("Unknown key type: %T", key) + } + if err == nil { + k.Encryption = "none" + } + return +} + // ParseRawPrivateKey returns a private key from a PuTTY encoded private key. It // supports RSA (PKCS#1), DSA (OpenSSL), ECDSA and ED25519 private keys. -func (k *Key) ParseRawPrivateKey(password []byte) (interface{}, error) { +func (k *Key) ParseRawPrivateKey(password []byte) (ret interface{}, err error) { if k.Encryption != "none" && len(password) == 0 { return nil, fmt.Errorf("expecting password") } - err := k.decrypt(password) + // Fall back if an error happens + priv := append([]byte{}, k.PrivateKey...) + defer func() { + if err != nil { + k.PrivateKey = priv + } else { + if k.Version == 3 { + k.KeyDerivation = "" + k.Argon2Memory = 0 + k.Argon2Passes = 0 + k.Argon2Parallelism = 0 + k.Argon2Salt = []byte{} + } + k.Encryption = "none" + if k.keySize == 0 { + k.keySize = len(k.PrivateKey) + } else { + k.PrivateKey = k.PrivateKey[:k.keySize] + } + k.calculateHMAC(nil) + } + }() + + err = k.decrypt(password) if err != nil { - return nil, err + return } switch k.Algo { case "ssh-rsa": - return k.readRSAPrivateKey() + ret, err = k.readRSAPrivateKey() case "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp521": - return k.readECDSAPrivateKey() + ret, err = k.readECDSAPrivateKey() case "ssh-dss": - return k.readDSAPrivateKey() + ret, err = k.readDSAPrivateKey() case "ssh-ed25519": - return k.readED25519PrivateKey() + ret, err = k.readED25519PrivateKey() + default: + return nil, fmt.Errorf("unsupported key type %q", k.Algo) } - - return nil, fmt.Errorf("unsupported key type %q", k.Algo) + return } // ParseRawPublicKey returns a public key from a PuTTY encoded private key. It @@ -273,7 +428,6 @@ func decodeFields(r reader) (*Key, error) { switch h { case puttyHeaderV1: k.Version = 1 - return nil, fmt.Errorf("PuTTY key format is too old") case puttyHeaderV2: k.Version = 2 case puttyHeaderV3: @@ -298,7 +452,9 @@ func decodeFields(r reader) (*Key, error) { // check the encryption format switch string(b) { case "none": + k.padded = false case "aes256-cbc": + k.padded = true default: return nil, fmt.Errorf("invalid encryption format: %s", b) } @@ -431,7 +587,7 @@ func decryptCBC(cipherKey, cipherIV, macKey, ciphertext []byte) error { // initialize AES 256 bit cipher cipherBlock, err := aes.NewCipher(cipherKey) if err != nil { - return fmt.Errorf("failed to initialize a cipher block: %v", err) + return fmt.Errorf("failed to initialize a cipher block for decrypt: %v", err) } // decrypt @@ -440,6 +596,72 @@ func decryptCBC(cipherKey, cipherIV, macKey, ciphertext []byte) error { return nil } +func encryptCBC(cipherKey, cipherIV, macKey, ciphertext []byte) error { + if len(ciphertext) < aes.BlockSize { + return fmt.Errorf("ciphertext is too short") + } + + if len(ciphertext)%aes.BlockSize != 0 { + return fmt.Errorf("ciphertext is not a multiple of the block size") + } + + // initialize AES 256 bit cipher + cipherBlock, err := aes.NewCipher(cipherKey) + if err != nil { + return fmt.Errorf("failed to initialize a cipher block for encrypt: %v", err) + } + + // encrypt + cipher.NewCBCEncrypter(cipherBlock, cipherIV).CryptBlocks(ciphertext, ciphertext) + + return nil +} + +// calculateHMAC calculates PuTTY key HMAC with a hash function +func (k *Key) calculateHMAC(password []byte) error { + _, _, macKey, err := k.deriveKeys(password) + if err != nil { + return err + } + keyCut := len(k.PrivateKey) + if k.Encryption == "none" { + if k.keySize == 0 { + k.keySize = len(k.PrivateKey) + } + keyCut = k.keySize + } + var hashFunc hash.Hash + switch k.Version { + case 1: + if k.Encryption == "none" { + k.PrivateMac = sha1.New().Sum(k.PrivateKey) + return nil + } else { + hashFunc = hmac.New(sha1.New, addPadding(k.PrivateMac)) + } + case 2: + hashFunc = hmac.New(sha1.New, macKey) + case 3: + hashFunc = hmac.New(sha256.New, macKey) + default: + return fmt.Errorf("unknown key format version: %d", k.Version) + } + + binary.Write(hashFunc, binary.BigEndian, uint32(len(k.Algo))) + hashFunc.Write([]byte(k.Algo)) + binary.Write(hashFunc, binary.BigEndian, uint32(len(k.Encryption))) + hashFunc.Write([]byte(k.Encryption)) + binary.Write(hashFunc, binary.BigEndian, uint32(len(k.Comment))) + hashFunc.Write([]byte(k.Comment)) + binary.Write(hashFunc, binary.BigEndian, uint32(len(k.PublicKey))) + hashFunc.Write(k.PublicKey) + binary.Write(hashFunc, binary.BigEndian, uint32(len(k.PrivateKey[:keyCut]))) + hashFunc.Write(k.PrivateKey) + + k.PrivateMac = hashFunc.Sum(nil) + return nil +} + // validateHMAC validates PuTTY key HMAC with a hash function func (k Key) validateHMAC(hashFunc hash.Hash) error { binary.Write(hashFunc, binary.BigEndian, uint32(len(k.Algo))) @@ -473,7 +695,7 @@ func (k Key) deriveKeys(password []byte) ([]byte, []byte, []byte, error) { macKey := sha1sum.Sum(nil) var seq int - var k []byte + var kb []byte // calculate and combine sha1 sums of each seq+password, // then truncate them to a 32 bytes (256 bit CBC) key @@ -481,21 +703,21 @@ func (k Key) deriveKeys(password []byte) ([]byte, []byte, []byte, error) { t := []byte{0, 0, 0, byte(seq)} t = append(t, password...) h := sha1.Sum(t) - k = append(k, h[:]...) - if len(k) >= 32 { + kb = append(kb, h[:]...) + if len(kb) >= 32 { break } seq++ } - if len(k) < cipherKeyLength { + if len(kb) < cipherKeyLength { return nil, nil, nil, fmt.Errorf("invalid length of the calculated cipher key") } // zero IV cipherIV := make([]byte, aes.BlockSize) - return k[:cipherKeyLength], cipherIV, macKey, nil + return kb[:cipherKeyLength], cipherIV, macKey, nil } var h []byte @@ -519,24 +741,33 @@ func (k Key) deriveKeys(password []byte) ([]byte, []byte, []byte, error) { nil } -// decrypt decrypts the key, when it is encrypted. and validates its signature -func (k *Key) decrypt(password []byte) error { +// Decrypt decrypts the key, when it is encrypted. and validates its signature +func (k *Key) decrypt(password []byte) (err error) { cipherKey, cipherIV, macKey, err := k.deriveKeys(password) if err != nil { return err } // decrypt the key, when it is encrypted - if !k.decrypted && k.Encryption != "none" { + if k.Encryption != "none" { err = decryptCBC(cipherKey, cipherIV, macKey, k.PrivateKey) if err != nil { return err } } - k.decrypted = true // validate key signature switch k.Version { + case 1: + if k.Encryption == "none" { + h := sha1.New().Sum(k.PrivateKey) + if !bytes.Equal(h, k.PrivateMac) { + return fmt.Errorf("calculated SHA1 sum %q doesn't correspond to %q", hex.EncodeToString(h), hex.EncodeToString(k.PrivateMac)) + } + return nil + } else { + err = k.validateHMAC(hmac.New(sha1.New, addPadding(k.PrivateMac))) + } case 2: err = k.validateHMAC(hmac.New(sha1.New, macKey)) case 3: @@ -544,9 +775,53 @@ func (k *Key) decrypt(password []byte) error { default: err = fmt.Errorf("unknown key format version: %d", k.Version) } + return +} + +// Encrypt encrypts the key and updates the HMAC +func (k *Key) Encrypt(random io.Reader, password []byte) error { + // Set a sensible value for an unset version number + if k.Version == 0 { + k.Version = 3 + } else if k.Version > 3 || k.Version < 0 { + return fmt.Errorf("unknown putty key version") + } + if k.Encryption != "none" { + return fmt.Errorf("decrypt the key first, then encrypt it") + } + if len(password) == 0 { + return fmt.Errorf("no password provided") + } + if k.keySize == 0 { + k.keySize = len(k.PrivateKey) + } + k.PrivateKey = addPadding(k.PrivateKey) + k.padded = true + + if k.Version == 3 && k.KeyDerivation == "" { + k.KeyDerivation = "Argon2id" + k.Argon2Memory = 8192 + k.Argon2Passes = 13 + k.Argon2Parallelism = 1 + salt := make([]byte, 16) + random.Read(salt) + k.Argon2Salt = salt + } + + cipherKey, cipherIV, macKey, err := k.deriveKeys(password) if err != nil { return err } + k.Encryption = "aes256-cbc" + err = k.calculateHMAC(password) + if err != nil { + return err + } + + err = encryptCBC(cipherKey, cipherIV, macKey, k.PrivateKey) + if err != nil { + return err + } return nil } diff --git a/putty_test.go b/putty_test.go index e58b57c..102dbf9 100644 --- a/putty_test.go +++ b/putty_test.go @@ -9,7 +9,6 @@ import ( "crypto/elliptic" "crypto/rsa" "encoding/hex" - "fmt" "math/big" "reflect" "strings" @@ -431,7 +430,7 @@ Comment: a@b` reader = strings.NewReader(privateKeyContent) _, err = decodeFields(bufio.NewReader(reader)) - if err == nil { + if err != nil { t.Errorf("Should have identified old key format") } @@ -457,6 +456,44 @@ func TestKey_Load(t *testing.T) { validateFields(t, key, mac1) } +func TestKey_WriteRsa(t *testing.T) { + key := &Key{ + Version: 2, + Algo: "ssh-rsa", + Encryption: "none", + Comment: "a@b", + PrivateMac: mac1, + } + N, _ := big.NewInt(0).SetString("8899603823917657073993455955247192742235475556705611494844755637605909744142362852675956970710170381872286959792213615771125246580768745178396163419345147", 10) + D, _ := big.NewInt(0).SetString("1202649165394277982972088642600971992193983183338596147951994005081879695154347826520147831485495538539062697907166363205375939469504219407479989786042173", 10) + P1, _ := big.NewInt(0).SetString("99432035046298583655632173300307896812294295981612330336064741365252891422281", 10) + P2, _ := big.NewInt(0).SetString("89504391816719133847764509922687382370233755361682176101456821678986111210787", 10) + Qinv, _ := big.NewInt(0).SetString("58739399210597955282522708421744401703943010346410985856809409636285000886096", 10) + key.setRSAPrivateKey(&rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + E: 37, + N: N, + }, + D: D, + Primes: []*big.Int{ + P1, + P2, + }, + Precomputed: rsa.PrecomputedValues{ + Qinv: Qinv, + }, + }) + + b, err := key.Marshal() + if err != nil { + t.Errorf("error marshalling Key fields: %v", err) + return + } + if !bytes.Equal([]byte(keyContent), cleanR(b)) { + t.Errorf("got:\n%s\n\nexpected:\n%s\n\n", string(b), keyContent) + } +} + func TestNew(t *testing.T) { key, err := New([]byte(keyContent)) if err != nil { @@ -491,6 +528,7 @@ func validateFields(t *testing.T, key *Key, mac []byte) { t.Error(err) return } + //fmt.Printf("k=%+v\n", k) } func TestParseRawPrivateKey(t *testing.T) { @@ -501,8 +539,14 @@ func TestParseRawPrivateKey(t *testing.T) { continue } + enc, err := key.Marshal() + if strings.TrimSpace(fixture.content) != strings.ReplaceAll( + strings.TrimSpace(string(enc)), "\r", "") { + } + v, err := key.ParseRawPrivateKey(fixture.password) if err != nil { + //t.Errorf("Expect:\n%s\nGot:\n%s", fixture.content, string(enc)) t.Errorf("error decrypting key #%d: %v", i, err) continue } @@ -510,12 +554,44 @@ func TestParseRawPrivateKey(t *testing.T) { switch v := v.(type) { case *dsa.PrivateKey, *rsa.PrivateKey, *ecdsa.PrivateKey, *ed25519.PrivateKey: if !reflect.DeepEqual(v, fixture.data) { - fmt.Printf("RSA key: %#+v\n", v) t.Errorf("error verifying a %T key #%d", v, i) } default: t.Errorf("unknown %T key #%d type", v, i) } + // Key is fully decoded now + + // The key is decrypted now, let's re-encrypt to verify that all goes back to the same thing + if fixture.password != nil { + salt, err := hex.DecodeString(string("745d60746c67666afa47dbf23226c6c9")) + err = key.Encrypt(bytes.NewReader(salt), fixture.password) + if err != nil { + t.Errorf("error encrypting key #%d: %v", i, err) + continue + } + } + if key.Encryption == "none" { + salt, err := hex.DecodeString(string("745d60746c67666afa47dbf23226c6c9")) + err = key.Encrypt(bytes.NewReader(salt), []byte("testagain")) + if err != nil { + t.Errorf("error encrypting key #%d: %v", i, err) + continue + } + _, err = key.ParseRawPrivateKey([]byte("testagain")) + if err != nil { + t.Errorf("error decrypting key #%d: %v", i, err) + continue + } + } + + enc, err = key.Marshal() + if err != nil { + t.Errorf("error marshalling key #%d: %v", i, err) + continue + } + if !bytes.Equal(cleanR([]byte(fixture.content)), cleanR(enc)) { + t.Errorf("Expect:\n%s\nGot:\n%s", fixture.content, string(enc)) + } } } @@ -564,3 +640,6 @@ func TestLoadFromUrandom(t *testing.T) { } t.Logf("%v", err) } +func cleanR(s []byte) []byte { + return []byte(strings.ReplaceAll(strings.TrimSpace(string(s)), "\r", "")) +} diff --git a/rsa.go b/rsa.go index 5537842..a350e17 100644 --- a/rsa.go +++ b/rsa.go @@ -12,7 +12,7 @@ func (k Key) readRSAPublicKey() (*rsa.PublicKey, error) { E *big.Int // pub exponent N *big.Int // pub modulus } - err := unmarshal(k.PublicKey, &pub, false) + _, err := unmarshal(k.PublicKey, &pub, false) if err != nil { return nil, err } @@ -29,7 +29,21 @@ func (k Key) readRSAPublicKey() (*rsa.PublicKey, error) { return publicKey, nil } -func (k Key) readRSAPrivateKey() (*rsa.PrivateKey, error) { +func (k *Key) setRSAPublicKey(pk *rsa.PublicKey) (err error) { + var pub struct { + Header string // header + E *big.Int // pub exponent + N *big.Int // pub modulus + } + k.Algo = "ssh-rsa" + pub.Header = "ssh-rsa" + pub.E = big.NewInt(int64(pk.E)) + pub.N = pk.N + k.PublicKey, err = marshal(&pub) + return +} + +func (k *Key) readRSAPrivateKey() (*rsa.PrivateKey, error) { publicKey, err := k.readRSAPublicKey() if err != nil { return nil, err @@ -41,7 +55,7 @@ func (k Key) readRSAPrivateKey() (*rsa.PrivateKey, error) { P2 *big.Int // prime 2 Qinv *big.Int // Qinv } - err = unmarshal(k.PrivateKey, &priv, k.Encryption != "none") + k.keySize, err = unmarshal(k.PrivateKey, &priv, k.padded) if err != nil { return nil, err } @@ -71,3 +85,30 @@ func (k Key) readRSAPrivateKey() (*rsa.PrivateKey, error) { return privateKey, nil } + +func (k *Key) setRSAPrivateKey(pk *rsa.PrivateKey) (err error) { + err = k.setRSAPublicKey(&pk.PublicKey) + if err != nil { + return err + } + + var priv struct { + D *big.Int // private exponent + P1 *big.Int // prime 1 + P2 *big.Int // prime 2 + Qinv *big.Int // Qinv + } + + priv.D = pk.D + priv.P1 = pk.Primes[0] + priv.P2 = pk.Primes[1] + // Make sure p > q + if priv.P1.Cmp(priv.P2) != 1 { + priv.P1, priv.P2 = priv.P2, priv.P1 + } + priv.Qinv = pk.Precomputed.Qinv + k.PrivateKey, err = marshal(&priv) + k.keySize = len(k.PrivateKey) + k.padded = false + return +} diff --git a/ssh1.go b/ssh1.go new file mode 100644 index 0000000..af84340 --- /dev/null +++ b/ssh1.go @@ -0,0 +1,256 @@ +package putty + +import ( + "bytes" + "crypto/des" + "crypto/md5" + "crypto/rand" + "crypto/rsa" + "fmt" + "io" + "math/big" + "strings" + + "github.com/pschou/go-cbc3" +) + +var ( + type_SSH1_CIPHER_DES3_CBC3 = byte(3) +) + +func (k *Key) MarshalSSH1WithPassword(password string) ([]byte, error) { + return k.saveSSH1([]byte(password), rand.Reader) +} +func (k *Key) MarshalSSH1() ([]byte, error) { + return k.saveSSH1(nil, rand.Reader) +} + +func (k *Key) saveSSH1(password []byte, rand io.Reader) ([]byte, error) { + if len(k.PrivateKey) == 0 { + // If we don't have a private key, just return the public key + pub, err := k.readRSAPublicKey() + if err != nil { + return nil, err + } + return []byte(fmt.Sprintf("%d %d %s %s\r\n", + pub.N.BitLen(), pub.E, pub.N, k.Comment)), nil + } + + // Parse the private and public keys + priv, err := k.readRSAPrivateKey() + if err != nil { + return nil, err + } + + // Write to buffer then return the contents + var pubBytes, privBytes bytes.Buffer + pubBytes.Write([]byte("SSH PRIVATE KEY FILE FORMAT 1.1\n\x00")) + if password == nil { + pubBytes.Write([]byte{0}) + } else { + pubBytes.Write([]byte{type_SSH1_CIPHER_DES3_CBC3}) + } + pubBytes.Write(make([]byte, 4)) + writeInt32(&pubBytes, uint32(priv.PublicKey.N.BitLen())) + writeBigInt1(&pubBytes, priv.PublicKey.N) + writeBigInt1(&pubBytes, big.NewInt(int64(priv.PublicKey.E))) + writeString2(&pubBytes, k.Comment) + + ab := make([]byte, 2) + rand.Read(ab) + privBytes.Write(ab) + privBytes.Write(ab) + + writeBigInt1(&privBytes, priv.D) + writeBigInt1(&privBytes, priv.Precomputed.Qinv) + writeBigInt1(&privBytes, priv.Primes[1]) + writeBigInt1(&privBytes, priv.Primes[0]) + + privB := privBytes.Bytes() + if len(privB)%8 > 0 { + privB = append(privB, make([]byte, 8-len(privB)%8)...) + } + if password != nil { + // Decrypt the private porition of the key + hash := md5.Sum(password) + key := hash[:] + c1, err := des.NewCipher(key[:8]) + if err != nil { + return nil, fmt.Errorf("Unable to build DES block 1, %s", err) + } + c2, err := des.NewCipher(key[8:]) + if err != nil { + return nil, fmt.Errorf("Unable to build DES block 2, %s", err) + } + c3, err := des.NewCipher(key[:8]) + if err != nil { + return nil, fmt.Errorf("Unable to build DES block 3, %s", err) + } + crypter := cbc3.NewEncrypter(c1, c2, c3, make([]byte, 24)) + crypter.CryptBlocks(privB, privB) + } + + //if !strings.HasPrefix(string(b), "SSH PRIVATE KEY FILE FORMAT 1.1\n") { + return append(pubBytes.Bytes(), privB...), nil +} + +func (k *Key) LoadSSH1WithPassword(b []byte, password string) error { + if !strings.HasPrefix(string(b), "SSH PRIVATE KEY FILE FORMAT 1.1\n") { + return fmt.Errorf("Expected: SSH PRIVATE KEY FILE FORMAT 1.1") + } + return k.loadSSH1(b, []byte(password)) +} +func (k *Key) LoadSSH1(b []byte) error { + if strings.HasPrefix(string(b), "SSH PRIVATE KEY FILE FORMAT 1.1\n") { + // If this file is a private key file + return k.loadSSH1(b, nil) + } else if parts := strings.SplitN(string(b), " ", 4); len(parts) == 4 { + // If this file is a public key file + e, _ := new(big.Int).SetString(parts[1], 10) + if e == nil { + return fmt.Errorf("Unable to read SSH1 exponent") + } + m, _ := new(big.Int).SetString(parts[2], 10) + if m == nil { + return fmt.Errorf("Unable to read SSH1 modulus") + } + k.Comment = strings.TrimSuffix(parts[3], "\r\n") + + return k.setRSAPublicKey(&rsa.PublicKey{ + N: m, + E: int(e.Int64()), + }) + } + return fmt.Errorf("Unknown SSH1 key file format") +} + +func (k *Key) loadSSH1(b, password []byte) error { + var encrypted bool + switch b[33] { + case 0: + if password != nil { + return fmt.Errorf("Password provided, but encryption flag is not set") + } + case type_SSH1_CIPHER_DES3_CBC3: + if password == nil { + return fmt.Errorf("Password not provided, but encryption flag is set") + } + encrypted = true + default: + return fmt.Errorf("Unsupported encryption %d", b[33]) + } + + src := bytes.NewReader(b[34:]) + { + zero, err := readInt32(src) + if err != nil { + return fmt.Errorf("Unable to read 4 byte padding, %s", err) + } + if zero != 0 { + return fmt.Errorf("Expected zero padding, got %d", zero) + } + } + bits, err := readInt32(src) + if err != nil { + return fmt.Errorf("Err reading the number of key bits (got %q), %s", bits, err) + } + + // This reader only expects the MODULUS to be first + var e, m *big.Int + m, err = readBigInt1(src) + if err != nil { + return fmt.Errorf("Unable to read SSH1 modulus, %s", err) + } + e, err = readBigInt1(src) + if err != nil { + return fmt.Errorf("Unable to read SSH1 exponent, %s", err) + } + + k.Comment, err = readString2(src) + if err != nil { + return fmt.Errorf("Unable to read SSH1 comment, %s", err) + } + + if encrypted { + // Decrypt the private porition of the key + hash := md5.Sum(password) + key := hash[:] + c1, err := des.NewCipher(key[:8]) + if err != nil { + return fmt.Errorf("Unable to build DES block 1, %s", err) + } + c2, err := des.NewCipher(key[8:]) + if err != nil { + return fmt.Errorf("Unable to build DES block 2, %s", err) + } + c3, err := des.NewCipher(key[:8]) + if err != nil { + return fmt.Errorf("Unable to build DES block 3, %s", err) + } + pos := len(b) - src.Len() + crypter := cbc3.NewDecrypter(c1, c2, c3, make([]byte, 24)) + crypter.CryptBlocks(b[pos:], b[pos:]) + + src = bytes.NewReader(b[pos:]) + } + + // Check the first 4 bytes for the pattern + { + ab := make([]byte, 4) + src.Read(ab) + + if ab[0] != ab[2] || ab[1] != ab[3] { + return fmt.Errorf("Expected [a] [b] [a] [b] values, wrong passphrase?") + } + } + + // Read in the private porition of the key + var private_exponent, iqmp, p, q *big.Int + private_exponent, err = readBigInt1(src) + if err != nil { + return fmt.Errorf("Unable to read SSH1 Private Exponent, %s", err) + } + + iqmp, err = readBigInt1(src) + if err != nil { + return fmt.Errorf("Unable to read SSH1 Qinv, %s", err) + } + q, err = readBigInt1(src) + if err != nil { + return fmt.Errorf("Unable to read SSH1 Q, %s", err) + } + p, err = readBigInt1(src) + if err != nil { + return fmt.Errorf("Unable to read SSH1 P, %s", err) + } + + /* + * Verify that the public data in an RSA key matches the private + * data. We also check the private data itself: we ensure that p > + * q and that iqmp really is the inverse of q mod p. + */ + + priv := rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: m, + E: int(e.Int64()), + }, + D: private_exponent, + Primes: []*big.Int{p, q}, + Precomputed: rsa.PrecomputedValues{ + Qinv: iqmp, + }, + } + + if err := priv.Validate(); err != nil { + return fmt.Errorf("Validation of key failed: %s", err) + } + + return k.setRSAPrivateKey(&priv) +} + +func des_xor(a, b []byte) { + for i := range a { + a[i] ^= b[i] + } +} diff --git a/ssh1_test.go b/ssh1_test.go new file mode 100644 index 0000000..75ecd11 --- /dev/null +++ b/ssh1_test.go @@ -0,0 +1,88 @@ +package putty + +import ( + "bytes" + "encoding/base64" + "fmt" + "strings" + "testing" +) + +var SSH1unencrypted = ` + U1NIIFBSSVZBVEUgS0VZIEZJTEUgRk9STUFUIDEuMQoAAAAAAAAAAAQABACyVhLTcHAKqu8YkxMR + fuq2ZtvfBAZ/ZD+TT6+sjhLSTQ+YjO2twb3Ku8eYiTKFcT40mSaMhq0Ei9YG1iGyLdDJLUF4s4HO + ua138J1SQJac1BDzWBy+PUqoeRk2TuvvwVFAUZ8ZlMz8suw7WvWWYnkqPVCCiHVDNLm9awpBP1y8 + lQAGJQAAABByc2Eta2V5LTIwMjIwOTE4u3i7eAP+MDLwVNJHy4gk8eKPiDAjwphXGa4PmA1BnW97 + lmuWYloENxFU/odjukCWz0esyh6bMM9yM9FfMalAw5PRwXQqlsg6xz/LY7tguSxtlBDxwluOnCcv + 7EereEdcSGTTB5iEZOcxRJrvMgIUQlHSzc9uDAnPslFOuADLrYeR7/uUhkUB/jGTA5jF6NDanWN2 + xyRCO1dumGfkhexTqGRA5WtUz/DxHZch1iD3Ek8dC0OF2jr7f8Ig3PbC4RfHyNd0pOZNTK4CALwe + cv5UIIslQKzt3POQpi3P8YAaE7oed/Di6m/325lS6AB4bfWIaAQeywg9Nep/meiS1AuDslwBl/5s + +JbXC5MCAPKv8Ukjifk78IWznkGHdFN+Jno3wEbLWeaUDNBNq0B64vm9LchpKHPo4VcsZvh7/WOj + mraBgaKTVpCa6lJ5wDcAAAAA` +var SSH1encrypted = ` + U1NIIFBSSVZBVEUgS0VZIEZJTEUgRk9STUFUIDEuMQoAAwAAAAAAAAQABACyVhLTcHAKqu8YkxMR + fuq2ZtvfBAZ/ZD+TT6+sjhLSTQ+YjO2twb3Ku8eYiTKFcT40mSaMhq0Ei9YG1iGyLdDJLUF4s4HO + ua138J1SQJac1BDzWBy+PUqoeRk2TuvvwVFAUZ8ZlMz8suw7WvWWYnkqPVCCiHVDNLm9awpBP1y8 + lQAGJQAAABByc2Eta2V5LTIwMjIwOTE4jGWS/2YMLF+EayIjJtsJYvfV5ZRhfWwvW6uZm9I+6Qyq + Jg2Rts81YB7iwlMBBEWxdHi+gOIx3p5RpP48QlXGXnv/8vv62yR/iadL802Rto6uIwN9WA8KGZ/a + +pe64e8xa3sYX9622XCT4pA8lB3Mb9+AiBzra+GSH8wLlU6k9IZusvCwK+/ToBlFCrWAeKLKHNBK + VuR2QjspFldSXj46AsUmTrFYgATQHCW8BkfMtZFYTFFi+ZkgrZMOM2hg0p4gVMNVw5YQLPdiyLjm + SKxOEFB/z1YygVd5PKS9rF3fw2UeSSXq02hoGEotZwmRMa7QAN4hJ7N/8KlDB9M1768mcOY9TD2j + Dv3NsaCgX0rD8+juS+L59QZyP9gOcOSIPq2o5etDcDKdZFPLDYKqAbKQK/As/5+1WRXfLy/XjTfN + Psg/DuQZf57RNQ3+y9wy2yqK` + +var SSH1Public = "000000077373682d72736100000001250000008100b25612d370700aaaef189313117eeab666dbdf04067f643f934fafac8e12d24d0f988cedadc1bdcabbc798893285713e3499268c86ad048bd606d621b22dd0c92d4178b381ceb9ad77f09d5240969cd410f3581cbe3d4aa87919364eebefc15140519f1994ccfcb2ec3b5af59662792a3d508288754334b9bd6b0a413f5cbc95" +var SSH1Private = "000000803032f054d247cb8824f1e28f883023c2985719ae0f980d419d6f7b966b96625a04371154fe8763ba4096cf47acca1e9b30cf7233d15f31a940c393d1c1742a96c83ac73fcb63bb60b92c6d9410f1c25b8e9c272fec47ab78475c4864d307988464e731449aef3202144251d2cdcf6e0c09cfb2514eb800cbad8791effb9486450000004100f2aff1492389f93bf085b39e418774537e267a37c046cb59e6940cd04dab407ae2f9bd2dc8692873e8e1572c66f87bfd63a39ab68181a29356909aea5279c0370000004100bc1e72fe54208b2540aceddcf390a62dcff1801a13ba1e77f0e2ea6ff7db9952e800786df58868041ecb083d35ea7f99e892d40b83b25c0197fe6cf896d70b930000004031930398c5e8d0da9d6376c724423b576e9867e485ec53a86440e56b54cff0f11d9721d620f7124f1d0b4385da3afb7fc220dcf6c2e117c7c8d774a4e64d4cae" +var SSH1PublicKey = "1024 37 125231955839861145597959941566719353978850591507101309939104935173887041843256466484940189588914165868561229567474080040371609970058249262542234488820048185261582821399489992085432366135601216753217583542250921463784984173203393090111095957397902456223378197593025272833066640581329273146588092056412215884949 rsa-key-20220918\r\n" + +func TestSSH1LoadEncrypted(t *testing.T) { + k := Key{} + err := k.LoadSSH1WithPassword(b64decode(SSH1encrypted), "testit") + if err != nil { + panic(err) + } + + if fmt.Sprintf("%02x", k.PublicKey) != SSH1Public { + panic("Error decoding public key") + } + if fmt.Sprintf("%02x", k.PrivateKey) != SSH1Private { + panic("Error decoding private key") + } + + priv, _ := k.saveSSH1([]byte("testit"), bytes.NewReader([]byte{111, 130})) + if !bytes.Equal(priv, b64decode(SSH1encrypted)) { + panic("Encoded SSH1 mismatch") + } + + priv, _ = k.saveSSH1(nil, bytes.NewReader([]byte{187, 120})) + if !bytes.Equal(priv, b64decode(SSH1unencrypted)) { + panic("Encoded SSH1 mismatch") + } + + k.PrivateKey = []byte{} + pub, _ := k.MarshalSSH1() + if string(pub) != SSH1PublicKey { + panic("Error encoding public key") + } +} + +func TestSSH1Load(t *testing.T) { + k := Key{} + err := k.LoadSSH1(b64decode(SSH1unencrypted)) + if err != nil { + panic(err) + } + + if fmt.Sprintf("%02x", k.PublicKey) != SSH1Public { + panic("Error decoding public key") + } + if fmt.Sprintf("%02x", k.PrivateKey) != SSH1Private { + panic("Error decoding private key") + } +} + +func b64decode(str string) []byte { + noWhiteSpace := strings.NewReplacer("\r", "", "\n", "", "\t", "", " ", "") + dat, _ := base64.StdEncoding.DecodeString(noWhiteSpace.Replace(str)) + return dat +} diff --git a/unmarshal.go b/unmarshal.go index b7b9598..2bbf97c 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -10,22 +10,36 @@ import ( "reflect" ) -func unmarshal(data []byte, val interface{}, enc bool) error { +func unmarshal(data []byte, val interface{}, padded bool) (keysize int, err error) { v := reflect.ValueOf(val).Elem() buf := bytes.NewReader(data) - err := parseField(v, buf) + err = parseField(v, buf) if err != nil { - return err + return } - // check key block size - err = checkGarbage(buf, enc) + // get the actual keysize + var size int64 + size, err = buf.Seek(0, io.SeekCurrent) + keysize = int(size) if err != nil { - return fmt.Errorf("wrong key size: %s", err) + return } - return nil + // check key block size + paddedSize := size + if padded { + // normalize the size of the decrypted part (should be % aes.BlockSize) + paddedSize = size + aes.BlockSize - size&(aes.BlockSize-1) + } + + // check key size + if buf.Size() != paddedSize { + return 0, fmt.Errorf("wrong key size, expected %d, got %d", paddedSize, buf.Size()) + } + + return } func parseField(v reflect.Value, src *bytes.Reader) error { @@ -48,19 +62,19 @@ func parseField(v reflect.Value, src *bytes.Reader) error { switch fieldType { case reflect.TypeOf(string("")): - parsedString, err := readString(src) + parsedString, err := readString2(src) if err != nil { return err } v.Set(reflect.ValueOf(parsedString)) case reflect.TypeOf([]byte(nil)): - parsedBytes, err := readBytes(src) + parsedBytes, err := readBytes2(src) if err != nil { return err } v.Set(reflect.ValueOf(parsedBytes)) case reflect.TypeOf(new(big.Int)): - parsedInt, err := readBigInt(src) + parsedInt, err := readBigInt2(src) if err != nil { return err } @@ -71,14 +85,23 @@ func parseField(v reflect.Value, src *bytes.Reader) error { return nil } +func readInt32(src *bytes.Reader) (uint32, error) { + var val uint32 + // read 4 bytes + + err := binary.Read(src, binary.BigEndian, &val) + return val, err +} + +func readBytes1(src *bytes.Reader) ([]byte, error) { + var length uint16 + // read 2 bytes (uint16 size) in bits of the next element -func readBytes(src *bytes.Reader) ([]byte, error) { - var length uint32 - // read 4 bytes (uint32 size) of the next element size err := binary.Read(src, binary.BigEndian, &length) if err != nil { return nil, err } + length = (length + 7) / 8 // get the current reader position pos, err := src.Seek(0, io.SeekCurrent) @@ -88,7 +111,10 @@ func readBytes(src *bytes.Reader) ([]byte, error) { // check next element size if int64(length)+pos > src.Size() { - return nil, fmt.Errorf("the element length %d is out of range", length) + /* SSH-1.5 spec says that it's OK for the prefix uint16 to be + * _greater_ than the actual number of bits */ + length = uint16(src.Size() - pos) + //return nil, fmt.Errorf("the element length %d is out of range", length) } buf := make([]byte, length) @@ -104,39 +130,71 @@ func readBytes(src *bytes.Reader) ([]byte, error) { return buf, nil } -func readString(src *bytes.Reader) (string, error) { - b, err := readBytes(src) +func readBigInt1(src *bytes.Reader) (*big.Int, error) { + b, err := readBytes1(src) if err != nil { - return "", err + return nil, err } - return string(b), nil + return new(big.Int).SetBytes(b), nil } -func readBigInt(src *bytes.Reader) (*big.Int, error) { - b, err := readBytes(src) +func readBigIntRemaining(src *bytes.Reader) (*big.Int, error) { + var b bytes.Buffer + _, err := b.ReadFrom(src) if err != nil { return nil, err } - return new(big.Int).SetBytes(b), nil + return new(big.Int).SetBytes(b.Bytes()), nil } -func checkGarbage(src *bytes.Reader, encrypted bool) error { +func readBytes2(src *bytes.Reader) ([]byte, error) { + var length uint32 + // read 4 bytes (uint32 size) of the next element size + err := binary.Read(src, binary.BigEndian, &length) + if err != nil { + return nil, err + } + + // get the current reader position pos, err := src.Seek(0, io.SeekCurrent) if err != nil { - return err + return nil, err + } + + // check next element size + if int64(length)+pos > src.Size() { + return nil, fmt.Errorf("the element length %d is out of range", length) + } + + buf := make([]byte, length) + n, err := io.ReadFull(src, buf) + if err != nil { + return nil, err } - if encrypted { - // normalize the size of the decrypted part (should be % aes.BlockSize) - pos = pos + aes.BlockSize - pos&(aes.BlockSize-1) + if n != int(length) { + return nil, fmt.Errorf("expected to read %d, but read %d", length, n) } - // check key size - if src.Size() != pos { - return fmt.Errorf("expected %d, got %d", pos, src.Size()) + return buf, nil +} + +func readString2(src *bytes.Reader) (string, error) { + b, err := readBytes2(src) + if err != nil { + return "", err } - return nil + return string(b), nil +} + +func readBigInt2(src *bytes.Reader) (*big.Int, error) { + b, err := readBytes2(src) + if err != nil { + return nil, err + } + + return new(big.Int).SetBytes(b), nil }