Skip to content

Commit

Permalink
refactor pki code
Browse files Browse the repository at this point in the history
  • Loading branch information
voigt committed Mar 24, 2023
1 parent 1f1f075 commit 508fc7e
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 96 deletions.
4 changes: 2 additions & 2 deletions cmd/pki/create/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (o *option) Validate() error {

func (o *option) Execute(_ context.Context) error {
if o.user != "" {
clientCSR, err := pki.CreateClientCSR(o.directory, o.domain, o.user)
clientCSR, err := pki.HandleCreateClientCSR(o.directory, o.domain, o.user)
if err != nil {
return fmt.Errorf("failed to create client csr: %w", err)
}
Expand All @@ -82,7 +82,7 @@ func (o *option) Execute(_ context.Context) error {
return nil
}

rootCA, err := pki.CreateAuraeRootCA(o.directory, o.domain)
rootCA, err := pki.HandleCreateAuraeRootCA(o.directory, o.domain)
if err != nil {
return fmt.Errorf("failed to create aurae root ca: %w", err)
}
Expand Down
241 changes: 152 additions & 89 deletions pkg/pki/pki.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ import (
"encoding/pem"
"fmt"
"math/big"
"os"
"path/filepath"
"time"
)

Expand All @@ -49,16 +47,142 @@ type Certificate struct {
PrivateKey string `json:"key" yaml:"key"`
}

func (c *Certificate) GetCertificate() (*x509.Certificate, error) {
crtPem, _ := pem.Decode([]byte(c.Certificate))
if crtPem == nil || crtPem.Type != "CERTIFICATE" {
return nil, fmt.Errorf("failed to decode certificate")
}

crt, err := x509.ParseCertificate(crtPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}

return crt, nil
}

func (c *Certificate) GetCertAsString() string {
return c.Certificate
}

func (c *Certificate) GetPrivateKey() (*rsa.PrivateKey, error) {
pkPem, _ := pem.Decode([]byte(c.PrivateKey))
if pkPem == nil || pkPem.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("failed to decode private key")
}

pk, err := x509.ParsePKCS1PrivateKey(pkPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}

return pk, nil
}

func (c *Certificate) GetPrivateKeyAsString() string {
return c.PrivateKey
}

func (c *Certificate) WriteCertificateToFile(path, filename string) error {
err := createFile(path, filename, c.GetCertAsString())
if err != nil {
return err
}
return nil
}

func (c *Certificate) WritePrivateKeyToFile(path, filename string) error {
err := createFile(path, filename, c.GetPrivateKeyAsString())
if err != nil {
return err
}
return nil
}

type CertificateRequest struct {
CSR string `json:"csr" yaml:"csr"`
PrivateKey string `json:"key" yaml:"key"`
User string `json:"user" yaml:"user"`
}

func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
func (c *CertificateRequest) GetCsr() (*x509.CertificateRequest, error) {
csrPem, _ := pem.Decode([]byte(c.CSR))
if csrPem == nil || csrPem.Type != "CERTIFICATE REQUEST" {
return nil, fmt.Errorf("failed to decode certificate request")
}

csr, err := x509.ParseCertificateRequest(csrPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate request: %w", err)
}

return csr, nil
}

func (c *CertificateRequest) GetCsrAsString() string {
return c.CSR
}

func (c *CertificateRequest) GetPrivateKey() (*rsa.PrivateKey, error) {
pkPem, _ := pem.Decode([]byte(c.PrivateKey))
if pkPem == nil || pkPem.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("failed to decode private key")
}

pk, err := x509.ParsePKCS1PrivateKey(pkPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}

return pk, nil
}

func (c *CertificateRequest) GetPrivateKeyAsString() string {
return c.PrivateKey
}

func (c *CertificateRequest) WriteCsrToFile(path, filename string) error {
err := createFile(path, filename, c.GetCsrAsString())
if err != nil {
return err
}
return nil
}

func (c *CertificateRequest) WritePrivateKeyToFile(path, filename string) error {
err := createFile(path, filename, c.GetPrivateKeyAsString())
if err != nil {
return err
}
return nil
}

func HandleCreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
crtPem, keyPem, err := createCA(domainName)
if err != nil {
return nil, err
}

ca := &Certificate{
Certificate: string(crtPem),
PrivateKey: string(keyPem),
}

if path != "" {
err = ca.WriteCertificateToFile(path, "ca.crt")
err = ca.WritePrivateKeyToFile(path, "ca.key")
if err != nil {
return ca, err
}
}

return ca, nil
}

func createCA(domainName string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to generate private key: %w", err)
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
}

subj := pkix.Name{
Expand All @@ -75,7 +199,7 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to generate serial number: %w", err)
return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
}

template := x509.Certificate{
Expand Down Expand Up @@ -107,7 +231,7 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {

crtBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to create certificate: %w", err)
return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
}

crtPem := pem.EncodeToMemory(&pem.Block{
Expand All @@ -117,28 +241,39 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {

keyPem := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
Bytes: crtBytes,
})

ca := &Certificate{
Certificate: string(crtPem),
PrivateKey: string(keyPem),
return crtPem, keyPem, nil
}

func HandleCreateClientCSR(path, domain, user string) (*CertificateRequest, error) {
csrPem, keyPem, err := createClientCSR(domain, user)
if err != nil {
return &CertificateRequest{}, err
}

csr := &CertificateRequest{
CSR: string(csrPem),
PrivateKey: string(keyPem),
User: user,
}

if path != "" {
err = createCAFiles(path, ca)
err = csr.WriteCsrToFile(path, fmt.Sprintf("client.%s.csr", csr.User))
err = csr.WritePrivateKeyToFile(path, fmt.Sprintf("client.%s.key", csr.User))
if err != nil {
return ca, err
return csr, err
}
}

return ca, nil
return csr, nil
}

func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) {
func createClientCSR(domain, user string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return &CertificateRequest{}, fmt.Errorf("failed to generate private key: %w", err)
return []byte{}, []byte{}, fmt.Errorf("failed to generate private key: %w", err)
}

subj := pkix.Name{
Expand All @@ -158,7 +293,7 @@ func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) {

csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, priv)
if err != nil {
return &CertificateRequest{}, fmt.Errorf("could not create certificate request: %w", err)
return []byte{}, []byte{}, fmt.Errorf("could not create certificate request: %w", err)
}

csrPem := pem.EncodeToMemory(&pem.Block{
Expand All @@ -171,77 +306,5 @@ func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) {
Bytes: x509.MarshalPKCS1PrivateKey(priv),
})

csr := &CertificateRequest{
CSR: string(csrPem),
PrivateKey: string(keyPem),
User: user,
}

if path != "" {
err = createCsrFiles(path, csr)
if err != nil {
return csr, err
}
}

return csr, nil
}

func createCAFiles(path string, ca *Certificate) error {
path = filepath.Clean(path)
err := os.MkdirAll(path, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}

crtPath := filepath.Join(path, "ca.crt")
keyPath := filepath.Join(path, "ca.key")

err = writeStringToFile(crtPath, ca.Certificate)
if err != nil {
return err
}

err = writeStringToFile(keyPath, ca.PrivateKey)
if err != nil {
return err
}
return nil
}

func createCsrFiles(path string, ca *CertificateRequest) error {
path = filepath.Clean(path)
err := os.MkdirAll(path, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}

csrPath := filepath.Join(path, fmt.Sprintf("client.%s.csr", ca.User))
keyPath := filepath.Join(path, fmt.Sprintf("client.%s.key", ca.User))

err = writeStringToFile(csrPath, ca.CSR)
if err != nil {
return err
}

err = writeStringToFile(keyPath, ca.PrivateKey)
if err != nil {
return err
}
return nil
}

func writeStringToFile(p string, s string) error {
f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return fmt.Errorf("failed to open file %s: %w", p, err)
}
defer f.Close()

_, err = f.WriteString(s)
if err != nil {
return fmt.Errorf("failed to write file %s: %w", p, err)
}

return nil
return csrPem, keyPem, nil
}
10 changes: 5 additions & 5 deletions pkg/pki/pki_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestCreateAuraeRootCA(t *testing.T) {
t.Run("createAuraeCA", func(t *testing.T) {
domainName := "unsafe.aurae.io"

auraeCa, err := CreateAuraeRootCA("", "unsafe.aurae.io")
auraeCa, err := HandleCreateAuraeRootCA("", "unsafe.aurae.io")
if err != nil {
t.Errorf("could not create auraeCA")
}
Expand All @@ -64,7 +64,7 @@ func TestCreateAuraeRootCA(t *testing.T) {
path := "_tmp/pki"
domainName := "unsafe.aurae.io"

_, err := CreateAuraeRootCA(path, domainName)
_, err := HandleCreateAuraeRootCA(path, domainName)
if err != nil {
t.Errorf("could not create auraeCA")
}
Expand Down Expand Up @@ -99,7 +99,7 @@ func TestCreateAuraeRootCA(t *testing.T) {

func TestCreateCSR(t *testing.T) {
t.Run("createCSR", func(t *testing.T) {
clientCsr, err := CreateClientCSR("", "unsafe.aurae.io", "christoph")
clientCsr, err := HandleCreateClientCSR("", "unsafe.aurae.io", "christoph")
if err != nil {
t.Errorf("could not create csr")
}
Expand All @@ -121,7 +121,7 @@ func TestCreateCSR(t *testing.T) {

t.Run("createCSR with local files", func(t *testing.T) {
path := "_tmp/pki"
clientCsr, err := CreateClientCSR(path, "unsafe.aurae.io", "christoph")
clientCsr, err := HandleCreateClientCSR(path, "unsafe.aurae.io", "christoph")
if err != nil {
t.Errorf("could not create csr")
}
Expand Down Expand Up @@ -224,7 +224,7 @@ func TestCreateCSR(t *testing.T) {
}

// Genenerate a new CSR
clientCsr, err := CreateClientCSR("", "unsafe.aurae.io", "christoph")
clientCsr, err := HandleCreateClientCSR("", "unsafe.aurae.io", "christoph")
if err != nil {
t.Errorf("could create csr")
}
Expand Down
Loading

0 comments on commit 508fc7e

Please sign in to comment.