From 18c4c418b73e8f3a609aadd14ca7f1ac81b872e9 Mon Sep 17 00:00:00 2001 From: Gerson Alexander Pardo Gamez Date: Sat, 18 Dec 2021 09:43:49 -0500 Subject: [PATCH] Fixed header by key type --- rsa/codec.go | 95 +++++++++++++++++++++++------------ rsa/common.go | 47 +++++++++++++---- rsa/convert_pkcs12_keypair.go | 6 ++- rsa/encrypt_privatekey.go | 11 ++-- 4 files changed, 111 insertions(+), 48 deletions(-) diff --git a/rsa/codec.go b/rsa/codec.go index f5a2638..f4c710e 100644 --- a/rsa/codec.go +++ b/rsa/codec.go @@ -1,6 +1,7 @@ package rsa import ( + "bufio" "bytes" "crypto" "crypto/md5" @@ -30,27 +31,19 @@ const ( PrivateKeyFormatTypePKCS8 ) -func getPrivateKeyFormatType(format string) PrivateKeyFormatType { - switch format { - case "pkcs8": - return PrivateKeyFormatTypePKCS8 - case "pkcs1": - return PrivateKeyFormatTypePKCS1 - default: - return PrivateKeyFormatTypePKCS1 - } -} +type HeaderPublicKeyType string -func getPublicKeyFormatType(format string) PublicKeyFormatType { - switch format { - case "pkix": - return PublicKeyFormatTypePKIX - case "pkcs1": - return PublicKeyFormatTypePKCS1 - default: - return PublicKeyFormatTypePKCS1 - } -} +const ( + HeaderPublicKeyPKCS1 = HeaderPublicKeyType("RSA PUBLIC KEY") + HeaderPublicKeyPKIX = HeaderPublicKeyType("PUBLIC KEY") +) + +type HeaderPrivateKeyType string + +const ( + HeaderPrivateKeyPKCS1 = HeaderPrivateKeyType("RSA PRIVATE KEY") + HeaderPrivateKeyPKCS8 = HeaderPrivateKeyType("PRIVATE KEY") +) func getSaltLength(length string) int { switch length { @@ -117,6 +110,28 @@ func getPEMCipher(name string) x509.PEMCipher { } } +func getPublicKeyHeaderByType(formatType PublicKeyFormatType) string { + switch formatType { + case PublicKeyFormatTypePKIX: + return string(HeaderPublicKeyPKIX) + case PublicKeyFormatTypePKCS1: + fallthrough + default: + return string(HeaderPublicKeyPKCS1) + } +} + +func getPrivateKeyHeaderByType(formatType PrivateKeyFormatType) string { + switch formatType { + case PrivateKeyFormatTypePKCS8: + return string(HeaderPrivateKeyPKCS8) + case PrivateKeyFormatTypePKCS1: + fallthrough + default: + return string(HeaderPrivateKeyPKCS1) + } +} + func encodePublicKey(publicKey interface{}, formatType PublicKeyFormatType) ([]byte, error) { var pemBytes []byte @@ -140,12 +155,12 @@ func encodePublicKey(publicKey interface{}, formatType PublicKeyFormatType) ([]b } break } - return pem.EncodeToMemory( + return encodePem( &pem.Block{ - Type: "PUBLIC KEY", + Type: getPublicKeyHeaderByType(formatType), Bytes: pemBytes, }, - ), nil + ) } func encodePrivateKey(privateKey interface{}, formatType PrivateKeyFormatType) ([]byte, error) { @@ -173,12 +188,25 @@ func encodePrivateKey(privateKey interface{}, formatType PrivateKeyFormatType) ( break } - return pem.EncodeToMemory( + return encodePem( &pem.Block{ - Type: "RSA PRIVATE KEY", + Type: getPrivateKeyHeaderByType(formatType), Bytes: pemBytes, }, - ), nil + ) +} + +// encodePem thanks to @patachi +func encodePem(b *pem.Block) ([]byte, error) { + buf := new(bytes.Buffer) + w := bufio.NewWriter(buf) + if err := pem.Encode(w, b); err != nil { + return nil, err + } + if err := w.Flush(); err != nil { + return nil, err + } + return buf.Bytes(), nil } func publicFromPrivate(privateKey interface{}) (interface{}, error) { @@ -191,9 +219,8 @@ func publicFromPrivate(privateKey interface{}) (interface{}, error) { } -func encodeCertificate(certificate *x509.Certificate) []byte { - - return pem.EncodeToMemory( +func encodeCertificate(certificate *x509.Certificate) ([]byte, error) { + return encodePem( &pem.Block{ Type: "CERTIFICATE", Bytes: certificate.Raw, @@ -211,8 +238,14 @@ func encodeToPEMBase64(input []byte) ([]byte, error) { if _, err := b64.Write(input); err != nil { return nil, err } - b64.Close() - breaker.Close() + err := b64.Close() + if err != nil { + return nil, err + } + err = breaker.Close() + if err != nil { + return nil, err + } return out.Bytes(), nil diff --git a/rsa/common.go b/rsa/common.go index b525abf..55e53b1 100644 --- a/rsa/common.go +++ b/rsa/common.go @@ -2,6 +2,7 @@ package rsa import ( "crypto/rsa" + "fmt" "golang.org/x/crypto/pkcs12" "crypto/x509" @@ -12,36 +13,60 @@ import ( func (r *FastRSA) readPrivateKey(privateKey string) (*rsa.PrivateKey, error) { - privateBlock, _ := pem.Decode([]byte(privateKey)) - if privateBlock == nil { - return nil, errors.New("invalid private key") + block, _ := pem.Decode([]byte(privateKey)) + if block == nil { + return nil, fmt.Errorf("invalid private key") } - if privateKeyCert, err := x509.ParsePKCS1PrivateKey(privateBlock.Bytes); err == nil { + switch block.Type { + case string(HeaderPrivateKeyPKCS1): + if privateKeyCert, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return privateKeyCert, nil + } + case string(HeaderPrivateKeyPKCS8): + if privateKeyCert, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { + return privateKeyCert.(*rsa.PrivateKey), nil + } + } + + // TODO remove this in the future because we need to use block.type instead + if privateKeyCert, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { return privateKeyCert, nil } - if privateKeyCert, err := x509.ParsePKCS8PrivateKey(privateBlock.Bytes); err == nil { + if privateKeyCert, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { return privateKeyCert.(*rsa.PrivateKey), nil } - return nil, errors.New("x509: unknown format") + return nil, fmt.Errorf("x509: unknown format for privateKey: %s", block.Type) } func (r *FastRSA) readPublicKey(publicKey string) (*rsa.PublicKey, error) { - publicBlock, _ := pem.Decode([]byte(publicKey)) - if publicBlock == nil { + block, _ := pem.Decode([]byte(publicKey)) + if block == nil { return nil, errors.New("invalid public key") } - if publicKeyCert, err := x509.ParsePKCS1PublicKey(publicBlock.Bytes); err == nil { + switch block.Type { + case string(HeaderPublicKeyPKCS1): + if publicKeyCert, err := x509.ParsePKCS1PublicKey(block.Bytes); err == nil { + return publicKeyCert, nil + } + case string(HeaderPublicKeyPKIX): + if publicKeyCert, err := x509.ParsePKIXPublicKey(block.Bytes); err == nil { + return publicKeyCert.(*rsa.PublicKey), nil + } + } + + // TODO remove this in the future because we need to use block.type instead + if publicKeyCert, err := x509.ParsePKCS1PublicKey(block.Bytes); err == nil { return publicKeyCert, nil } - if publicKeyCert, err := x509.ParsePKIXPublicKey(publicBlock.Bytes); err == nil { + if publicKeyCert, err := x509.ParsePKIXPublicKey(block.Bytes); err == nil { return publicKeyCert.(*rsa.PublicKey), nil } - return nil, errors.New("x509: unknown format") + return nil, fmt.Errorf("x509: unknown format for publicKey: %s", block.Type) } func (r *FastRSA) readPKCS12(data, password string) (interface{}, *x509.Certificate, error) { diff --git a/rsa/convert_pkcs12_keypair.go b/rsa/convert_pkcs12_keypair.go index 0f9d8a1..8fc87ea 100644 --- a/rsa/convert_pkcs12_keypair.go +++ b/rsa/convert_pkcs12_keypair.go @@ -30,7 +30,11 @@ func (r *FastRSA) ConvertPKCS12ToKeyPair(pkcs12, password string) (*PKCS12KeyPai certificateEncoded := "" if certificate != nil { - certificateEncoded = string(encodeCertificate(certificate)) + certificateBytes, err := encodeCertificate(certificate) + if err != nil { + return nil, err + } + certificateEncoded = string(certificateBytes) } keyPair = &PKCS12KeyPair{ PrivateKey: string(privateKey), diff --git a/rsa/encrypt_privatekey.go b/rsa/encrypt_privatekey.go index 201a12e..0ad4fc2 100644 --- a/rsa/encrypt_privatekey.go +++ b/rsa/encrypt_privatekey.go @@ -3,7 +3,6 @@ package rsa import ( "crypto/rand" "crypto/x509" - "encoding/pem" ) func (r *FastRSA) EncryptPrivateKey(privateKey, password, cipherName string) (string, error) { @@ -15,11 +14,13 @@ func (r *FastRSA) EncryptPrivateKey(privateKey, password, cipherName string) (st // TODO should be valid choose custom marshal pemBytes := x509.MarshalPKCS1PrivateKey(privateKeyCertKeyBase) - block, err := x509.EncryptPEMBlock(rand.Reader, "RSA PRIVATE KEY", pemBytes, []byte(password), getPEMCipher(cipherName)) + block, err := x509.EncryptPEMBlock(rand.Reader, getPrivateKeyHeaderByType(PrivateKeyFormatTypePKCS1), pemBytes, []byte(password), getPEMCipher(cipherName)) if err != nil { return "", err } - output := pem.EncodeToMemory(block) - - return string(output), nil + output, err := encodePem(block) + if err != nil { + return "", err + } + return string(output), err }