Skip to content

Commit

Permalink
split certificate match function from validate function
Browse files Browse the repository at this point in the history
  • Loading branch information
kayra1 committed Apr 3, 2024
1 parent 221f27b commit 3d96fe0
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 23 deletions.
5 changes: 4 additions & 1 deletion internal/certdb/certdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ func (db *CertificateRequestsRepository) Create(csr string) (int64, error) {
// Update adds a new cert to the given CSR in the repository.
// The given certificate must share the public key of the CSR and must be valid.
func (db *CertificateRequestsRepository) Update(csr string, cert string) (int64, error) {
if err := ValidateCertificate(cert, csr); err != nil {
if err := ValidateCertificate(cert); err != nil {
return 0, err
}
if err := CertificateMatchesCSR(cert, csr); err != nil {
return 0, err
}
result, err := db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.table), cert, csr)
Expand Down
38 changes: 24 additions & 14 deletions internal/certdb/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
// ValidateCertificateRequest validates the given CSR string to the following:
// The string must be a valid PEM string, and should be of type CERTIFICATE REQUEST
// The PEM string should be able to be parsed into a x509 Certificate Request
func ValidateCertificateRequest(csrString string) error {
block, _ := pem.Decode([]byte(csrString))
func ValidateCertificateRequest(csr string) error {
block, _ := pem.Decode([]byte(csr))
if block == nil {
return errors.New("PEM Certificate Request string not found or malformed")
}
Expand All @@ -26,29 +26,39 @@ func ValidateCertificateRequest(csrString string) error {
}

// ValidateCertificate validates the given Cert string to the following:
// The given CSR must pass the validation provided by ValidateCertificateRequest
// The cert string must be a valid PEM string, and should be of type CERTIFICATE
// The PEM string should be able to be parsed into a x509 Certificate
// The given cert and CSR must share the same public key
func ValidateCertificate(certString string, csrString string) error {
if err := ValidateCertificateRequest(csrString); err != nil {
return err
}
csrBlock, _ := pem.Decode([]byte(csrString))
csr, _ := x509.ParseCertificateRequest(csrBlock.Bytes)
certBlock, _ := pem.Decode([]byte(certString))
func ValidateCertificate(cert string) error {
certBlock, _ := pem.Decode([]byte(cert))
if certBlock == nil {
return errors.New("PEM Certificate string not found or malformed")
}
if certBlock.Type != "CERTIFICATE" {
return errors.New("given PEM string not a certificate")
}
cert, err := x509.ParseCertificate(certBlock.Bytes)
_, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return err
}
certKey := cert.PublicKey.(*rsa.PublicKey)
csrKey := csr.PublicKey.(*rsa.PublicKey)
return nil
}

// CertificateMatchesCSR makes sure that the given certificate and CSR match.
// The given CSR and Cert must pass their respective validation functions
// The given cert and CSR must share the same public key
func CertificateMatchesCSR(cert string, csr string) error {
if err := ValidateCertificateRequest(csr); err != nil {
return err
}
if err := ValidateCertificate(cert); err != nil {
return err
}
csrBlock, _ := pem.Decode([]byte(csr))
parsedCSR, _ := x509.ParseCertificateRequest(csrBlock.Bytes)
certBlock, _ := pem.Decode([]byte(cert))
parsedCERT, _ := x509.ParseCertificate(certBlock.Bytes)
certKey := parsedCERT.PublicKey.(*rsa.PublicKey)
csrKey := parsedCSR.PublicKey.(*rsa.PublicKey)
if !csrKey.Equal(certKey) {
return errors.New("certificate does not match CSR")
}
Expand Down
51 changes: 43 additions & 8 deletions internal/certdb/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func TestCertValidationSuccess(t *testing.T) {

for i, c := range cases {
t.Run(fmt.Sprintf("ValidCert%d", i), func(t *testing.T) {
if err := certdb.ValidateCertificate(c, ValidCSR2); err != nil {
if err := certdb.ValidateCertificate(c); err != nil {
t.Errorf("Couldn't verify valid Cert: %s", err)
}
})
Expand All @@ -162,33 +162,68 @@ func TestCertValidationFail(t *testing.T) {
var wrongPemTypeErr = "given PEM string not a certificate"
var InvalidCert = strings.ReplaceAll(ValidCert2, "M", "i")
var InvalidCertErr = "x509: malformed certificate"
var certificateDoesNotMatchErr = "certificate does not match CSR"

cases := []struct {
inputCSR string
inputCert string
expectedErr string
}{
{
inputCSR: ValidCSR2,
inputCert: wrongCertString,
expectedErr: wrongCertStringErr,
},
{
inputCSR: ValidCSR2,
inputCert: ValidCertWithoutWhitespace,
expectedErr: ValidCertWithoutWhitespaceErr,
},
{
inputCSR: ValidCSR2,
inputCert: wrongPemType,
expectedErr: wrongPemTypeErr,
},
{
inputCSR: ValidCSR2,
inputCert: InvalidCert,
expectedErr: InvalidCertErr,
},
}

for i, c := range cases {
t.Run(fmt.Sprintf("InvalidCert%d", i), func(t *testing.T) {
err := certdb.ValidateCertificate(c.inputCert)
if err.Error() != c.expectedErr {
t.Errorf("Expected error not found:\nReceived: %s\n Expected: %s", err, c.expectedErr)
}
})
}
}

func TestCertificateMatchesCSRSuccess(t *testing.T) {
cases := []struct {
inputCSR string
inputCert string
}{
{
inputCSR: ValidCSR2,
inputCert: ValidCert2,
},
}

for i, c := range cases {
t.Run(fmt.Sprintf("InvalidCert%d", i), func(t *testing.T) {
err := certdb.CertificateMatchesCSR(c.inputCert, c.inputCSR)
if err != nil {
t.Errorf("Certificate did not match when it should have")
}
})
}
}

func TestCertificateMatchesCSRFail(t *testing.T) {
var certificateDoesNotMatchErr = "certificate does not match CSR"

cases := []struct {
inputCSR string
inputCert string
expectedErr string
}{
{
inputCSR: ValidCSR1,
inputCert: ValidCert2,
Expand All @@ -198,7 +233,7 @@ func TestCertValidationFail(t *testing.T) {

for i, c := range cases {
t.Run(fmt.Sprintf("InvalidCert%d", i), func(t *testing.T) {
err := certdb.ValidateCertificate(c.inputCert, c.inputCSR)
err := certdb.CertificateMatchesCSR(c.inputCert, c.inputCSR)
if err.Error() != c.expectedErr {
t.Errorf("Expected error not found:\nReceived: %s\n Expected: %s", err, c.expectedErr)
}
Expand Down

0 comments on commit 3d96fe0

Please sign in to comment.