Skip to content

Commit

Permalink
Fixed header by key type
Browse files Browse the repository at this point in the history
  • Loading branch information
jerson committed Dec 18, 2021
1 parent 64cc9c4 commit 18c4c41
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 48 deletions.
95 changes: 64 additions & 31 deletions rsa/codec.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rsa

import (
"bufio"
"bytes"
"crypto"
"crypto/md5"
Expand Down Expand Up @@ -30,27 +31,19 @@ const (
PrivateKeyFormatTypePKCS8
)

func getPrivateKeyFormatType(format string) PrivateKeyFormatType {
switch format {
case "pkcs8":
return PrivateKeyFormatTypePKCS8
case "pkcs1":
return PrivateKeyFormatTypePKCS1
default:
return PrivateKeyFormatTypePKCS1
}
}
type HeaderPublicKeyType string

func getPublicKeyFormatType(format string) PublicKeyFormatType {
switch format {
case "pkix":
return PublicKeyFormatTypePKIX
case "pkcs1":
return PublicKeyFormatTypePKCS1
default:
return PublicKeyFormatTypePKCS1
}
}
const (
HeaderPublicKeyPKCS1 = HeaderPublicKeyType("RSA PUBLIC KEY")
HeaderPublicKeyPKIX = HeaderPublicKeyType("PUBLIC KEY")
)

type HeaderPrivateKeyType string

const (
HeaderPrivateKeyPKCS1 = HeaderPrivateKeyType("RSA PRIVATE KEY")
HeaderPrivateKeyPKCS8 = HeaderPrivateKeyType("PRIVATE KEY")
)

func getSaltLength(length string) int {
switch length {
Expand Down Expand Up @@ -117,6 +110,28 @@ func getPEMCipher(name string) x509.PEMCipher {
}
}

func getPublicKeyHeaderByType(formatType PublicKeyFormatType) string {
switch formatType {
case PublicKeyFormatTypePKIX:
return string(HeaderPublicKeyPKIX)
case PublicKeyFormatTypePKCS1:
fallthrough
default:
return string(HeaderPublicKeyPKCS1)
}
}

func getPrivateKeyHeaderByType(formatType PrivateKeyFormatType) string {
switch formatType {
case PrivateKeyFormatTypePKCS8:
return string(HeaderPrivateKeyPKCS8)
case PrivateKeyFormatTypePKCS1:
fallthrough
default:
return string(HeaderPrivateKeyPKCS1)
}
}

func encodePublicKey(publicKey interface{}, formatType PublicKeyFormatType) ([]byte, error) {

var pemBytes []byte
Expand All @@ -140,12 +155,12 @@ func encodePublicKey(publicKey interface{}, formatType PublicKeyFormatType) ([]b
}
break
}
return pem.EncodeToMemory(
return encodePem(
&pem.Block{
Type: "PUBLIC KEY",
Type: getPublicKeyHeaderByType(formatType),
Bytes: pemBytes,
},
), nil
)
}

func encodePrivateKey(privateKey interface{}, formatType PrivateKeyFormatType) ([]byte, error) {
Expand Down Expand Up @@ -173,12 +188,25 @@ func encodePrivateKey(privateKey interface{}, formatType PrivateKeyFormatType) (
break
}

return pem.EncodeToMemory(
return encodePem(
&pem.Block{
Type: "RSA PRIVATE KEY",
Type: getPrivateKeyHeaderByType(formatType),
Bytes: pemBytes,
},
), nil
)
}

// encodePem thanks to @patachi
func encodePem(b *pem.Block) ([]byte, error) {
buf := new(bytes.Buffer)
w := bufio.NewWriter(buf)
if err := pem.Encode(w, b); err != nil {
return nil, err
}
if err := w.Flush(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

func publicFromPrivate(privateKey interface{}) (interface{}, error) {
Expand All @@ -191,9 +219,8 @@ func publicFromPrivate(privateKey interface{}) (interface{}, error) {

}

func encodeCertificate(certificate *x509.Certificate) []byte {

return pem.EncodeToMemory(
func encodeCertificate(certificate *x509.Certificate) ([]byte, error) {
return encodePem(
&pem.Block{
Type: "CERTIFICATE",
Bytes: certificate.Raw,
Expand All @@ -211,8 +238,14 @@ func encodeToPEMBase64(input []byte) ([]byte, error) {
if _, err := b64.Write(input); err != nil {
return nil, err
}
b64.Close()
breaker.Close()
err := b64.Close()
if err != nil {
return nil, err
}
err = breaker.Close()
if err != nil {
return nil, err
}

return out.Bytes(), nil

Expand Down
47 changes: 36 additions & 11 deletions rsa/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rsa

import (
"crypto/rsa"
"fmt"
"golang.org/x/crypto/pkcs12"

"crypto/x509"
Expand All @@ -12,36 +13,60 @@ import (

func (r *FastRSA) readPrivateKey(privateKey string) (*rsa.PrivateKey, error) {

privateBlock, _ := pem.Decode([]byte(privateKey))
if privateBlock == nil {
return nil, errors.New("invalid private key")
block, _ := pem.Decode([]byte(privateKey))
if block == nil {
return nil, fmt.Errorf("invalid private key")
}

if privateKeyCert, err := x509.ParsePKCS1PrivateKey(privateBlock.Bytes); err == nil {
switch block.Type {
case string(HeaderPrivateKeyPKCS1):
if privateKeyCert, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return privateKeyCert, nil
}
case string(HeaderPrivateKeyPKCS8):
if privateKeyCert, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
return privateKeyCert.(*rsa.PrivateKey), nil
}
}

// TODO remove this in the future because we need to use block.type instead
if privateKeyCert, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return privateKeyCert, nil
}
if privateKeyCert, err := x509.ParsePKCS8PrivateKey(privateBlock.Bytes); err == nil {
if privateKeyCert, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
return privateKeyCert.(*rsa.PrivateKey), nil
}

return nil, errors.New("x509: unknown format")
return nil, fmt.Errorf("x509: unknown format for privateKey: %s", block.Type)
}

func (r *FastRSA) readPublicKey(publicKey string) (*rsa.PublicKey, error) {

publicBlock, _ := pem.Decode([]byte(publicKey))
if publicBlock == nil {
block, _ := pem.Decode([]byte(publicKey))
if block == nil {
return nil, errors.New("invalid public key")
}

if publicKeyCert, err := x509.ParsePKCS1PublicKey(publicBlock.Bytes); err == nil {
switch block.Type {
case string(HeaderPublicKeyPKCS1):
if publicKeyCert, err := x509.ParsePKCS1PublicKey(block.Bytes); err == nil {
return publicKeyCert, nil
}
case string(HeaderPublicKeyPKIX):
if publicKeyCert, err := x509.ParsePKIXPublicKey(block.Bytes); err == nil {
return publicKeyCert.(*rsa.PublicKey), nil
}
}

// TODO remove this in the future because we need to use block.type instead
if publicKeyCert, err := x509.ParsePKCS1PublicKey(block.Bytes); err == nil {
return publicKeyCert, nil
}
if publicKeyCert, err := x509.ParsePKIXPublicKey(publicBlock.Bytes); err == nil {
if publicKeyCert, err := x509.ParsePKIXPublicKey(block.Bytes); err == nil {
return publicKeyCert.(*rsa.PublicKey), nil
}

return nil, errors.New("x509: unknown format")
return nil, fmt.Errorf("x509: unknown format for publicKey: %s", block.Type)
}

func (r *FastRSA) readPKCS12(data, password string) (interface{}, *x509.Certificate, error) {
Expand Down
6 changes: 5 additions & 1 deletion rsa/convert_pkcs12_keypair.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ func (r *FastRSA) ConvertPKCS12ToKeyPair(pkcs12, password string) (*PKCS12KeyPai

certificateEncoded := ""
if certificate != nil {
certificateEncoded = string(encodeCertificate(certificate))
certificateBytes, err := encodeCertificate(certificate)
if err != nil {
return nil, err
}
certificateEncoded = string(certificateBytes)
}
keyPair = &PKCS12KeyPair{
PrivateKey: string(privateKey),
Expand Down
11 changes: 6 additions & 5 deletions rsa/encrypt_privatekey.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package rsa
import (
"crypto/rand"
"crypto/x509"
"encoding/pem"
)

func (r *FastRSA) EncryptPrivateKey(privateKey, password, cipherName string) (string, error) {
Expand All @@ -15,11 +14,13 @@ func (r *FastRSA) EncryptPrivateKey(privateKey, password, cipherName string) (st

// TODO should be valid choose custom marshal
pemBytes := x509.MarshalPKCS1PrivateKey(privateKeyCertKeyBase)
block, err := x509.EncryptPEMBlock(rand.Reader, "RSA PRIVATE KEY", pemBytes, []byte(password), getPEMCipher(cipherName))
block, err := x509.EncryptPEMBlock(rand.Reader, getPrivateKeyHeaderByType(PrivateKeyFormatTypePKCS1), pemBytes, []byte(password), getPEMCipher(cipherName))
if err != nil {
return "", err
}
output := pem.EncodeToMemory(block)

return string(output), nil
output, err := encodePem(block)
if err != nil {
return "", err
}
return string(output), err
}

0 comments on commit 18c4c41

Please sign in to comment.