Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor pki code #65

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/pki/create/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (o *option) Validate() error {

func (o *option) Execute(_ context.Context) error {
if o.user != "" {
clientCSR, err := pki.CreateClientCSR(o.directory, o.domain, o.user)
clientCSR, err := pki.HandleCreateClientCSR(o.directory, o.domain, o.user)
if err != nil {
return fmt.Errorf("failed to create client csr: %w", err)
}
Expand All @@ -82,7 +82,7 @@ func (o *option) Execute(_ context.Context) error {
return nil
}

rootCA, err := pki.CreateAuraeRootCA(o.directory, o.domain)
rootCA, err := pki.HandleCreateAuraeRootCA(o.directory, o.domain)
if err != nil {
return fmt.Errorf("failed to create aurae root ca: %w", err)
}
Expand Down
247 changes: 158 additions & 89 deletions pkg/pki/pki.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ import (
"encoding/pem"
"fmt"
"math/big"
"os"
"path/filepath"
"time"
)

Expand All @@ -49,16 +47,145 @@ type Certificate struct {
PrivateKey string `json:"key" yaml:"key"`
}

func (c *Certificate) GetCertificate() (*x509.Certificate, error) {
crtPem, _ := pem.Decode([]byte(c.Certificate))
if crtPem == nil || crtPem.Type != "CERTIFICATE" {
return nil, fmt.Errorf("failed to decode certificate")
}

crt, err := x509.ParseCertificate(crtPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}

return crt, nil
}

func (c *Certificate) GetCertAsString() string {
return c.Certificate
}

func (c *Certificate) GetPrivateKey() (*rsa.PrivateKey, error) {
pkPem, _ := pem.Decode([]byte(c.PrivateKey))
if pkPem == nil || pkPem.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("failed to decode private key")
}

pk, err := x509.ParsePKCS1PrivateKey(pkPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}

return pk, nil
}

func (c *Certificate) GetPrivateKeyAsString() string {
return c.PrivateKey
}

func (c *Certificate) WriteCertificateToFile(path, filename string) error {
err := createFile(path, filename, c.GetCertAsString())
if err != nil {
return err
}
return nil
}

func (c *Certificate) WritePrivateKeyToFile(path, filename string) error {
err := createFile(path, filename, c.GetPrivateKeyAsString())
if err != nil {
return err
}
return nil
}

type CertificateRequest struct {
CSR string `json:"csr" yaml:"csr"`
PrivateKey string `json:"key" yaml:"key"`
User string `json:"user" yaml:"user"`
}

func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
func (c *CertificateRequest) GetCsr() (*x509.CertificateRequest, error) {
csrPem, _ := pem.Decode([]byte(c.CSR))
if csrPem == nil || csrPem.Type != "CERTIFICATE REQUEST" {
return nil, fmt.Errorf("failed to decode certificate request")
}

csr, err := x509.ParseCertificateRequest(csrPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate request: %w", err)
}

return csr, nil
}

func (c *CertificateRequest) GetCsrAsString() string {
return c.CSR
}

func (c *CertificateRequest) GetPrivateKey() (*rsa.PrivateKey, error) {
pkPem, _ := pem.Decode([]byte(c.PrivateKey))
if pkPem == nil || pkPem.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("failed to decode private key")
}

pk, err := x509.ParsePKCS1PrivateKey(pkPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}

return pk, nil
}

func (c *CertificateRequest) GetPrivateKeyAsString() string {
return c.PrivateKey
}

func (c *CertificateRequest) WriteCsrToFile(path, filename string) error {
err := createFile(path, filename, c.GetCsrAsString())
if err != nil {
return err
}
return nil
}

func (c *CertificateRequest) WritePrivateKeyToFile(path, filename string) error {
err := createFile(path, filename, c.GetPrivateKeyAsString())
if err != nil {
return err
}
return nil
}

func HandleCreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
crtPem, keyPem, err := createCA(domainName)
if err != nil {
return nil, err
}

ca := &Certificate{
Certificate: string(crtPem),
PrivateKey: string(keyPem),
}

if path != "" {
err = ca.WriteCertificateToFile(path, "ca.crt")
if err != nil {
return ca, err
}
err = ca.WritePrivateKeyToFile(path, "ca.key")
if err != nil {
return ca, err
}
}

return ca, nil
}

func createCA(domainName string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to generate private key: %w", err)
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
}

subj := pkix.Name{
Expand All @@ -75,7 +202,7 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to generate serial number: %w", err)
return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
}

template := x509.Certificate{
Expand Down Expand Up @@ -107,7 +234,7 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {

crtBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to create certificate: %w", err)
return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
}

crtPem := pem.EncodeToMemory(&pem.Block{
Expand All @@ -117,28 +244,42 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {

keyPem := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
Bytes: crtBytes,
})

ca := &Certificate{
Certificate: string(crtPem),
PrivateKey: string(keyPem),
return crtPem, keyPem, nil
}

func HandleCreateClientCSR(path, domain, user string) (*CertificateRequest, error) {
csrPem, keyPem, err := createClientCSR(domain, user)
if err != nil {
return &CertificateRequest{}, err
}

csr := &CertificateRequest{
CSR: string(csrPem),
PrivateKey: string(keyPem),
User: user,
}

if path != "" {
err = createCAFiles(path, ca)
err = csr.WriteCsrToFile(path, fmt.Sprintf("client.%s.csr", csr.User))
if err != nil {
return ca, err
return csr, err
}
err = csr.WritePrivateKeyToFile(path, fmt.Sprintf("client.%s.key", csr.User))
if err != nil {
return csr, err
}
}

return ca, nil
return csr, nil
}

func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) {
func createClientCSR(domain, user string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return &CertificateRequest{}, fmt.Errorf("failed to generate private key: %w", err)
return []byte{}, []byte{}, fmt.Errorf("failed to generate private key: %w", err)
}

subj := pkix.Name{
Expand All @@ -158,7 +299,7 @@ func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) {

csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, priv)
if err != nil {
return &CertificateRequest{}, fmt.Errorf("could not create certificate request: %w", err)
return []byte{}, []byte{}, fmt.Errorf("could not create certificate request: %w", err)
}

csrPem := pem.EncodeToMemory(&pem.Block{
Expand All @@ -171,77 +312,5 @@ func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) {
Bytes: x509.MarshalPKCS1PrivateKey(priv),
})

csr := &CertificateRequest{
CSR: string(csrPem),
PrivateKey: string(keyPem),
User: user,
}

if path != "" {
err = createCsrFiles(path, csr)
if err != nil {
return csr, err
}
}

return csr, nil
}

func createCAFiles(path string, ca *Certificate) error {
path = filepath.Clean(path)
err := os.MkdirAll(path, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}

crtPath := filepath.Join(path, "ca.crt")
keyPath := filepath.Join(path, "ca.key")

err = writeStringToFile(crtPath, ca.Certificate)
if err != nil {
return err
}

err = writeStringToFile(keyPath, ca.PrivateKey)
if err != nil {
return err
}
return nil
}

func createCsrFiles(path string, ca *CertificateRequest) error {
path = filepath.Clean(path)
err := os.MkdirAll(path, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}

csrPath := filepath.Join(path, fmt.Sprintf("client.%s.csr", ca.User))
keyPath := filepath.Join(path, fmt.Sprintf("client.%s.key", ca.User))

err = writeStringToFile(csrPath, ca.CSR)
if err != nil {
return err
}

err = writeStringToFile(keyPath, ca.PrivateKey)
if err != nil {
return err
}
return nil
}

func writeStringToFile(p string, s string) error {
f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return fmt.Errorf("failed to open file %s: %w", p, err)
}
defer f.Close()

_, err = f.WriteString(s)
if err != nil {
return fmt.Errorf("failed to write file %s: %w", p, err)
}

return nil
return csrPem, keyPem, nil
}
10 changes: 5 additions & 5 deletions pkg/pki/pki_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestCreateAuraeRootCA(t *testing.T) {
t.Run("createAuraeCA", func(t *testing.T) {
domainName := "unsafe.aurae.io"

auraeCa, err := CreateAuraeRootCA("", "unsafe.aurae.io")
auraeCa, err := HandleCreateAuraeRootCA("", "unsafe.aurae.io")
if err != nil {
t.Errorf("could not create auraeCA")
}
Expand All @@ -64,7 +64,7 @@ func TestCreateAuraeRootCA(t *testing.T) {
path := "_tmp/pki"
domainName := "unsafe.aurae.io"

_, err := CreateAuraeRootCA(path, domainName)
_, err := HandleCreateAuraeRootCA(path, domainName)
if err != nil {
t.Errorf("could not create auraeCA")
}
Expand Down Expand Up @@ -99,7 +99,7 @@ func TestCreateAuraeRootCA(t *testing.T) {

func TestCreateCSR(t *testing.T) {
t.Run("createCSR", func(t *testing.T) {
clientCsr, err := CreateClientCSR("", "unsafe.aurae.io", "christoph")
clientCsr, err := HandleCreateClientCSR("", "unsafe.aurae.io", "christoph")
if err != nil {
t.Errorf("could not create csr")
}
Expand All @@ -121,7 +121,7 @@ func TestCreateCSR(t *testing.T) {

t.Run("createCSR with local files", func(t *testing.T) {
path := "_tmp/pki"
clientCsr, err := CreateClientCSR(path, "unsafe.aurae.io", "christoph")
clientCsr, err := HandleCreateClientCSR(path, "unsafe.aurae.io", "christoph")
if err != nil {
t.Errorf("could not create csr")
}
Expand Down Expand Up @@ -224,7 +224,7 @@ func TestCreateCSR(t *testing.T) {
}

// Genenerate a new CSR
clientCsr, err := CreateClientCSR("", "unsafe.aurae.io", "christoph")
clientCsr, err := HandleCreateClientCSR("", "unsafe.aurae.io", "christoph")
if err != nil {
t.Errorf("could create csr")
}
Expand Down
Loading