diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3997bea --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.db \ No newline at end of file diff --git a/cmd/gocert/main.go b/cmd/gocert/main.go index ed53fd7..42d1148 100644 --- a/cmd/gocert/main.go +++ b/cmd/gocert/main.go @@ -1,5 +1,5 @@ package main func main() { - // ... + // ListenAndServe } diff --git a/go.mod b/go.mod index 50e6ec0..103d8a4 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/canonical/gocert go 1.22.1 + +require github.com/mattn/go-sqlite3 v1.14.22 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e8d092a --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= diff --git a/internal/certdb/certdb.go b/internal/certdb/certdb.go new file mode 100644 index 0000000..297db38 --- /dev/null +++ b/internal/certdb/certdb.go @@ -0,0 +1,136 @@ +// Package certdb provides a simplistic ORM to communicate with an SQL database for storage +package certdb + +import ( + "database/sql" + "fmt" + + _ "github.com/mattn/go-sqlite3" +) + +const queryCreateTable = "CREATE TABLE IF NOT EXISTS %s (CSR VARCHAR PRIMARY KEY UNIQUE NOT NULL, Certificate VARCHAR DEFAULT '')" + +const queryGetAllCSRs = "SELECT rowid, * FROM %s" +const queryGetCSR = "SELECT rowid, * FROM %s WHERE CSR=?" +const queryCreateCSR = "INSERT INTO %s (CSR) VALUES (?)" +const queryUpdateCSR = "UPDATE %s SET Certificate=? WHERE CSR=?" +const queryDeleteCSR = "DELETE FROM %s WHERE CSR=?" + +// CertificateRequestRepository is the object used to communicate with the established repository. +type CertificateRequestsRepository struct { + table string + conn *sql.DB +} + +// A CertificateRequest struct represents an entry in the database. +// The object contains a Certificate Request, its matching Certificate if any, and the row ID. +type CertificateRequest struct { + ID int + CSR string + Certificate string +} + +// RetrieveAll gets every CertificateRequest entry in the table. +func (db *CertificateRequestsRepository) RetrieveAll() ([]CertificateRequest, error) { + rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.table)) + if err != nil { + return nil, err + } + + var allCsrs []CertificateRequest + defer rows.Close() + for rows.Next() { + var csr CertificateRequest + if err := rows.Scan(&csr.ID, &csr.CSR, &csr.Certificate); err != nil { + return nil, err + } + allCsrs = append(allCsrs, csr) + } + return allCsrs, nil +} + +// Retrieve gets a given CSR from the repository. +// It returns the row id and matching certificate alongside the CSR in a CertificateRequest object. +func (db *CertificateRequestsRepository) Retrieve(csr string) (CertificateRequest, error) { + var newCSR CertificateRequest + row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.table), csr) + if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil { + return newCSR, err + } + return newCSR, nil +} + +// Create creates a new entry in the repository. +// The given CSR must be valid and unique +func (db *CertificateRequestsRepository) Create(csr string) (int64, error) { + if err := ValidateCertificateRequest(csr); err != nil { + return 0, err + } + result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.table), csr) + if err != nil { + return 0, err + } + id, err := result.LastInsertId() + if err != nil { + return 0, err + } + return id, nil +} + +// Update adds a new cert to the given CSR in the repository. +// The given certificate must share the public key of the CSR and must be valid. +func (db *CertificateRequestsRepository) Update(csr string, cert string) (int64, error) { + if err := ValidateCertificate(cert); err != nil { + return 0, err + } + if err := CertificateMatchesCSR(cert, csr); err != nil { + return 0, err + } + result, err := db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.table), cert, csr) + if err != nil { + return 0, err + } + id, err := result.LastInsertId() + if err != nil { + return 0, err + } + return id, nil +} + +// Delete removes a CSR from the database alongside the certificate that may have been generated for it. +func (db *CertificateRequestsRepository) Delete(csr string) error { + _, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.table), csr) + if err != nil { + return err + } + return nil +} + +// Close closes the connection to the repository cleanly. +func (db *CertificateRequestsRepository) Close() error { + if db.conn == nil { + return nil + } + if err := db.conn.Close(); err != nil { + return err + } + return nil +} + +// NewCertificateRequestsRepository connects to a given table in a given database, +// stores the connection information and returns an object containing the information. +// The database path must be a valid file path or ":memory:". +// The table will be created if it doesn't exist in the format expected by the package. +func NewCertificateRequestsRepository(databasePath string, tableName string) (*CertificateRequestsRepository, error) { + conn, err := sql.Open("sqlite3", databasePath) + if err != nil { + return nil, err + } + if _, err := conn.Exec(fmt.Sprintf(queryCreateTable, tableName)); err != nil { + return nil, err + } + db := new(CertificateRequestsRepository) + db.conn = conn + db.table = tableName + return db, nil +} diff --git a/internal/certdb/certdb_test.go b/internal/certdb/certdb_test.go new file mode 100644 index 0000000..4dfe946 --- /dev/null +++ b/internal/certdb/certdb_test.go @@ -0,0 +1,134 @@ +package certdb_test + +import ( + "log" + "strings" + "testing" + + "github.com/canonical/gocert/internal/certdb" +) + +func TestConnect(t *testing.T) { + db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReqs") + if err != nil { + t.Fatalf("Can't connect to SQLite: %s", err) + } + db.Close() +} + +func TestEndToEnd(t *testing.T) { + db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck + if err != nil { + t.Fatalf("Couldn't complete NewCertificateRequestsRepository: %s", err) + } + defer db.Close() + + if _, err := db.Create(ValidCSR1); err != nil { + t.Fatalf("Couldn't complete Create: %s", err) + } + if _, err := db.Create(ValidCSR2); err != nil { + t.Fatalf("Couldn't complete Create: %s", err) + } + if _, err := db.Create(ValidCSR3); err != nil { + t.Fatalf("Couldn't complete Create: %s", err) + } + + res, err := db.RetrieveAll() + if err != nil { + t.Fatalf("Couldn't complete RetrieveAll: %s", err) + } + if len(res) != 3 { + t.Fatalf("One or more CSRs weren't found in DB") + } + retrievedCSR, err := db.Retrieve(ValidCSR1) + if err != nil { + t.Fatalf("Couldn't complete Retrieve: %s", err) + } + if retrievedCSR.CSR != ValidCSR1 { + t.Fatalf("The CSR from the database doesn't match the CSR that was given") + } + + if err = db.Delete(ValidCSR1); err != nil { + t.Fatalf("Couldn't complete Delete: %s", err) + } + res, _ = db.RetrieveAll() + if len(res) != 2 { + t.Fatalf("CSR's weren't deleted from the DB properly") + } + + _, err = db.Update(ValidCSR2, ValidCert2) + if err != nil { + t.Fatalf("Couldn't complete Update: %s", err) + } + retrievedCSR, _ = db.Retrieve(ValidCSR2) + if retrievedCSR.Certificate != ValidCert2 { + t.Fatalf("The certificate that was uploaded does not match the certificate that was given: Retrieved: %s\nGiven: %s", retrievedCSR.Certificate, ValidCert2) + } +} + +func TestCreateFails(t *testing.T) { + db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReqs") //nolint:errcheck + defer db.Close() + + InvalidCSR := strings.ReplaceAll(ValidCSR1, "/", "+") + if _, err := db.Create(InvalidCSR); err == nil { + t.Fatalf("Expected error due to invalid CSR") + } + + db.Create(ValidCSR1) //nolint:errcheck + if _, err := db.Create(ValidCSR1); err == nil { + t.Fatalf("Expected error due to duplicate CSR") + } +} + +func TestUpdateFails(t *testing.T) { + db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck + defer db.Close() + + db.Create(ValidCSR1) //nolint:errcheck + db.Create(ValidCSR2) //nolint:errcheck + InvalidCert := strings.ReplaceAll(ValidCert2, "/", "+") + if _, err := db.Update(ValidCSR2, InvalidCert); err == nil { + t.Fatalf("Expected updating with invalid cert to fail") + } + if _, err := db.Update(ValidCSR1, ValidCert2); err == nil { + t.Fatalf("Expected updating with mismatched cert to fail") + } +} + +func TestRetrieve(t *testing.T) { + db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck + defer db.Close() + + db.Create(ValidCSR1) //nolint:errcheck + if _, err := db.Retrieve(ValidCSR2); err == nil { + t.Fatalf("Expected failure looking for nonexistent CSR") + } + +} + +func Example() { + db, err := certdb.NewCertificateRequestsRepository("./certs.db", "CertificateReq") + if err != nil { + log.Fatalln(err) + } + _, err = db.Create(ValidCSR2) + if err != nil { + log.Fatalln(err) + } + _, err = db.Update(ValidCSR2, ValidCert2) + if err != nil { + log.Fatalln(err) + } + entry, err := db.Retrieve(ValidCSR2) + if err != nil { + log.Fatalln(err) + } + if entry.Certificate != ValidCert2 { + log.Fatalln("Retrieved Certificate doesn't match Stored Certificate") + } + err = db.Close() + if err != nil { + log.Fatalln(err) + } +} diff --git a/internal/certdb/validation.go b/internal/certdb/validation.go new file mode 100644 index 0000000..9a487a4 --- /dev/null +++ b/internal/certdb/validation.go @@ -0,0 +1,66 @@ +package certdb + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" +) + +// ValidateCertificateRequest validates the given CSR string to the following: +// The string must be a valid PEM string, and should be of type CERTIFICATE REQUEST +// The PEM string should be able to be parsed into a x509 Certificate Request +func ValidateCertificateRequest(csr string) error { + block, _ := pem.Decode([]byte(csr)) + if block == nil { + return errors.New("PEM Certificate Request string not found or malformed") + } + if block.Type != "CERTIFICATE REQUEST" { + return errors.New("given PEM string not a certificate request") + } + _, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + return err + } + return nil +} + +// ValidateCertificate validates the given Cert string to the following: +// The cert string must be a valid PEM string, and should be of type CERTIFICATE +// The PEM string should be able to be parsed into a x509 Certificate +func ValidateCertificate(cert string) error { + certBlock, _ := pem.Decode([]byte(cert)) + if certBlock == nil { + return errors.New("PEM Certificate string not found or malformed") + } + if certBlock.Type != "CERTIFICATE" { + return errors.New("given PEM string not a certificate") + } + _, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return err + } + return nil +} + +// CertificateMatchesCSR makes sure that the given certificate and CSR match. +// The given CSR and Cert must pass their respective validation functions +// The given cert and CSR must share the same public key +func CertificateMatchesCSR(cert string, csr string) error { + if err := ValidateCertificateRequest(csr); err != nil { + return err + } + if err := ValidateCertificate(cert); err != nil { + return err + } + csrBlock, _ := pem.Decode([]byte(csr)) + parsedCSR, _ := x509.ParseCertificateRequest(csrBlock.Bytes) + certBlock, _ := pem.Decode([]byte(cert)) + parsedCERT, _ := x509.ParseCertificate(certBlock.Bytes) + certKey := parsedCERT.PublicKey.(*rsa.PublicKey) + csrKey := parsedCSR.PublicKey.(*rsa.PublicKey) + if !csrKey.Equal(certKey) { + return errors.New("certificate does not match CSR") + } + return nil +} diff --git a/internal/certdb/validation_test.go b/internal/certdb/validation_test.go new file mode 100644 index 0000000..6e0940e --- /dev/null +++ b/internal/certdb/validation_test.go @@ -0,0 +1,242 @@ +package certdb_test + +import ( + "fmt" + "strings" + "testing" + + "github.com/canonical/gocert/internal/certdb" +) + +var ValidCSR1 string = `-----BEGIN CERTIFICATE REQUEST----- +MIICszCCAZsCAQAwFjEUMBIGA1UEAwwLZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3 +DQEBAQUAA4IBDwAwggEKAoIBAQDC5KgrADpuOUPwSh0YLmpWF66VTcciIGC2HcGn +oJknL7pm5q9qhfWGIdvKKlIA6cBB32jPd0QcYDsx7+AvzEvBuO7mq7v2Q1sPU4Q+ +L0s2pLJges6/cnDWvk/p5eBjDLOqHhUNzpMUga9SgIod8yymTZm3eqQvt1ABdwTg +FzBs5QdSm2Ny1fEbbcRE+Rv5rqXyJb2isXSujzSuS22VqslDIyqnY5WaLg+pjZyR ++0j13ecJsdh6/MJMUZWheimV2Yv7SFtxzFwbzBMO9YFS098sy4F896eBHLNe9cUC ++d1JDtLaewlMogjHBHAxmP54dhe6vvc78anElKKP4hm5N5nlAgMBAAGgWDBWBgkq +hkiG9w0BCQ4xSTBHMA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcD +AQYIKwYBBQUHAwIwFgYDVR0RBA8wDYILZXhhbXBsZS5jb20wDQYJKoZIhvcNAQEL +BQADggEBACP1VKEGVYKoVLMDJS+EZ0CPwIYWsO4xBXgK6atHe8WIChVn/8I7eo60 +cuMDiy4LR70G++xL1tpmYGRbx21r9d/shL2ehp9VdClX06qxlcGxiC/F8eThRuS5 +zHcdNqSVyMoLJ0c7yWHJahN5u2bn1Lov34yOEqGGpWCGF/gT1nEvM+p/v30s89f2 +Y/uPl4g3jpGqLCKTASWJDGnZLroLICOzYTVs5P3oj+VueSUwYhGK5tBnS2x5FHID +uMNMgwl0fxGMQZjrlXyCBhXBm1k6PmwcJGJF5LQ31c+5aTTMFU7SyZhlymctB8mS +y+ErBQsRpcQho6Ok+HTXQQUcx7WNcwI= +-----END CERTIFICATE REQUEST----- +` +var ValidCSR2 string = `-----BEGIN CERTIFICATE REQUEST----- +MIIC5zCCAc8CAQAwRzEWMBQGA1UEAwwNMTAuMTUyLjE4My41MzEtMCsGA1UELQwk +MzlhY2UxOTUtZGM1YS00MzJiLTgwOTAtYWZlNmFiNGI0OWNmMIIBIjANBgkqhkiG +9w0BAQEFAAOCAQ8AMIIBCgKCAQEAjM5Wz+HRtDveRzeDkEDM4ornIaefe8d8nmFi +pUat9qCU3U9798FR460DHjCLGxFxxmoRitzHtaR4ew5H036HlGB20yas/CMDgSUI +69DyAsyPwEJqOWBGO1LL50qXdl5/jOkO2voA9j5UsD1CtWSklyhbNhWMpYqj2ObW +XcaYj9Gx/TwYhw8xsJ/QRWyCrvjjVzH8+4frfDhBVOyywN7sq+I3WwCbyBBcN8uO +yae0b/q5+UJUiqgpeOAh/4Y7qI3YarMj4cm7dwmiCVjedUwh65zVyHtQUfLd8nFW +Kl9775mNBc1yicvKDU3ZB5hZ1MZtpbMBwaA1yMSErs/fh5KaXwIDAQABoFswWQYJ +KoZIhvcNAQkOMUwwSjBIBgNVHREEQTA/hwQKmLc1gjd2YXVsdC1rOHMtMC52YXVs +dC1rOHMtZW5kcG9pbnRzLnZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsMA0GCSqGSIb3 +DQEBCwUAA4IBAQCJt8oVDbiuCsik4N5AOJIT7jKsMb+j0mizwjahKMoCHdx+zv0V +FGkhlf0VWPAdEu3gHdJfduX88WwzJ2wBBUK38UuprAyvfaZfaYUgFJQNC6DH1fIa +uHYEhvNJBdFJHaBvW7lrSFi57fTA9IEPrB3m/XN3r2F4eoHnaJJqHZmMwqVHck87 +cAQXk3fvTWuikHiCHqqdSdjDYj/8cyiwCrQWpV245VSbOE0WesWoEnSdFXVUfE1+ +RSKeTRuuJMcdGqBkDnDI22myj0bjt7q8eqBIjTiLQLnAFnQYpcCrhc8dKU9IJlv1 +H9Hay4ZO9LRew3pEtlx2WrExw/gpUcWM8rTI +-----END CERTIFICATE REQUEST-----` + +var ValidCSR3 string = `-----BEGIN CERTIFICATE REQUEST----- +MIICszCCAZsCAQAwFjEUMBIGA1UEAwwLZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3 +DQEBAQUAA4IBDwAwggEKAoIBAQDN7tHggWTtxiT5Sh5Npoif8J2BdpJjtMdpZ7Vu +NVzMxW/eojSRlq0p3nafmpjnSdSH1k/XMmPsgmv9txxEHMw1LIUJUef2QVrQTI6J +4ueu9NvexZWXZ+UxFip63PKyn/CkZRFiHCRIGzDDPxM2aApjghXy9ISMtGqDVSnr +5hQDu2U1CEiUWKMoTpyk/KlBZliDDOzaGm3cQuzKWs6Stjzpq+uX4ecJAXZg5Cj+ ++JUETH93A/VOfsiiHXoKeTnFMCsmJgEHz2DZixw8EN8XgpOp5BA2n8Y/xS+Ren5R +ZH7uNJI/SmQ0yrR+2bYR6hm+4bCzspyCfzbiuI5IS9+2eXA/AgMBAAGgWDBWBgkq +hkiG9w0BCQ4xSTBHMA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcD +AQYIKwYBBQUHAwIwFgYDVR0RBA8wDYILZXhhbXBsZS5jb20wDQYJKoZIhvcNAQEL +BQADggEBAB/aPfYLbnCubYyKnxLRipoLr3TBSYFnRfcxiZR1o+L3/tuv2NlrXJjY +K13xzzPhwuZwd6iKfX3xC33sKgnUNFawyE8IuAmyhJ2cl97iA2lwoYcyuWP9TOEx +LT60zxp7PHsKo53gqaqRJ5B9RZtiv1jYdUZvynHP4J5JG7Zwaa0VNi/Cx5cwGW8K +rfvNABPUAU6xIqqYgd2heDPF6kjvpoNiOl056qIAbk0dbmpqOJf/lxKBRfqlHhSC +0qRScGu70l2Oxl89YSsfGtUyQuzTkLshI2VkEUM+W/ZauXbxLd8SyWveH3/7mDC+ +Sgi7T+lz+c1Tw+XFgkqryUwMeG2wxt8= +-----END CERTIFICATE REQUEST----- +` + +var ValidCert2 string = `-----BEGIN CERTIFICATE----- +MIIDrDCCApSgAwIBAgIURKr+jf7hj60SyAryIeN++9wDdtkwDQYJKoZIhvcNAQEL +BQAwOTELMAkGA1UEBhMCVVMxKjAoBgNVBAMMIXNlbGYtc2lnbmVkLWNlcnRpZmlj +YXRlcy1vcGVyYXRvcjAeFw0yNDAzMjcxMjQ4MDRaFw0yNTAzMjcxMjQ4MDRaMEcx +FjAUBgNVBAMMDTEwLjE1Mi4xODMuNTMxLTArBgNVBC0MJDM5YWNlMTk1LWRjNWEt +NDMyYi04MDkwLWFmZTZhYjRiNDljZjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC +AQoCggEBAIzOVs/h0bQ73kc3g5BAzOKK5yGnn3vHfJ5hYqVGrfaglN1Pe/fBUeOt +Ax4wixsRccZqEYrcx7WkeHsOR9N+h5RgdtMmrPwjA4ElCOvQ8gLMj8BCajlgRjtS +y+dKl3Zef4zpDtr6APY+VLA9QrVkpJcoWzYVjKWKo9jm1l3GmI/Rsf08GIcPMbCf +0EVsgq7441cx/PuH63w4QVTsssDe7KviN1sAm8gQXDfLjsmntG/6uflCVIqoKXjg +If+GO6iN2GqzI+HJu3cJoglY3nVMIeuc1ch7UFHy3fJxVipfe++ZjQXNconLyg1N +2QeYWdTGbaWzAcGgNcjEhK7P34eSml8CAwEAAaOBnTCBmjAhBgNVHSMEGjAYgBYE +FN/vgl9cAapV7hH9lEyM7qYS958aMB0GA1UdDgQWBBRJJDZkHr64VqTC24DPQVld +Ba3iPDAMBgNVHRMBAf8EAjAAMEgGA1UdEQRBMD+CN3ZhdWx0LWs4cy0wLnZhdWx0 +LWs4cy1lbmRwb2ludHMudmF1bHQuc3ZjLmNsdXN0ZXIubG9jYWyHBAqYtzUwDQYJ +KoZIhvcNAQELBQADggEBAEH9NTwDiSsoQt/QXkWPMBrB830K0dlwKl5WBNgVxFP+ +hSfQ86xN77jNSp2VxOksgzF9J9u/ubAXvSFsou4xdP8MevBXoFJXeqMERq5RW3gc +WyhXkzguv3dwH+n43GJFP6MQ+n9W/nPZCUQ0Iy7ueAvj0HFhGyZzAE2wxNFZdvCs +gCX3nqYpp70oZIFDrhmYwE5ij5KXlHD4/1IOfNUKCDmQDgGPLI1tVtwQLjeRq7Hg +XVelpl/LXTQawmJyvDaVT/Q9P+WqoDiMjrqF6Sy7DzNeeccWVqvqX5TVS6Ky56iS +Mvo/+PAJHkBciR5Xn+Wg2a+7vrZvT6CBoRSOTozlLSM= +-----END CERTIFICATE-----` + +func TestCSRValidationSuccess(t *testing.T) { + cases := []string{ValidCSR1, ValidCSR2, ValidCSR3} + + for i, c := range cases { + t.Run(fmt.Sprintf("ValidCSR%d", i), func(t *testing.T) { + if err := certdb.ValidateCertificateRequest(c); err != nil { + t.Errorf("Couldn't verify valid CSR: %s", err) + } + }) + } +} + +func TestCSRValidationFail(t *testing.T) { + var wrongString = "this is a real csr!!!" + var wrongStringErr = "PEM Certificate Request string not found or malformed" + var ValidCSRWithoutWhitespace = strings.ReplaceAll(ValidCSR1, "\n", "") + var ValidCSRWithoutWhitespaceErr = "PEM Certificate Request string not found or malformed" + var wrongPemType = strings.ReplaceAll(ValidCSR1, "CERTIFICATE REQUEST", "SOME RANDOM PEM TYPE") + var wrongPemTypeErr = "given PEM string not a certificate request" + var InvalidCSR = strings.ReplaceAll(ValidCSR1, "/", "p") + var InvalidCSRErr = "asn1: syntax error: invalid boolean" + + cases := []struct { + input string + expectedErr string + }{ + { + input: wrongString, + expectedErr: wrongStringErr, + }, + { + input: ValidCSRWithoutWhitespace, + expectedErr: ValidCSRWithoutWhitespaceErr, + }, + { + input: wrongPemType, + expectedErr: wrongPemTypeErr, + }, + { + input: InvalidCSR, + expectedErr: InvalidCSRErr, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("InvalidCSR%d", i), func(t *testing.T) { + err := certdb.ValidateCertificateRequest(c.input) + if err.Error() != c.expectedErr { + t.Errorf("Expected error not found:\nReceived: %s\nExpected: %s", err, c.expectedErr) + } + }) + } +} + +func TestCertValidationSuccess(t *testing.T) { + cases := []string{ValidCert2} + + for i, c := range cases { + t.Run(fmt.Sprintf("ValidCert%d", i), func(t *testing.T) { + if err := certdb.ValidateCertificate(c); err != nil { + t.Errorf("Couldn't verify valid Cert: %s", err) + } + }) + } +} + +func TestCertValidationFail(t *testing.T) { + var wrongCertString = "this is a real cert!!!" + var wrongCertStringErr = "PEM Certificate string not found or malformed" + var ValidCertWithoutWhitespace = strings.ReplaceAll(ValidCert2, "\n", "") + var ValidCertWithoutWhitespaceErr = "PEM Certificate string not found or malformed" + var wrongPemType = strings.ReplaceAll(ValidCert2, "CERTIFICATE", "SOME RANDOM PEM TYPE") + var wrongPemTypeErr = "given PEM string not a certificate" + var InvalidCert = strings.ReplaceAll(ValidCert2, "M", "i") + var InvalidCertErr = "x509: malformed certificate" + + cases := []struct { + inputCert string + expectedErr string + }{ + { + inputCert: wrongCertString, + expectedErr: wrongCertStringErr, + }, + { + inputCert: ValidCertWithoutWhitespace, + expectedErr: ValidCertWithoutWhitespaceErr, + }, + { + inputCert: wrongPemType, + expectedErr: wrongPemTypeErr, + }, + { + inputCert: InvalidCert, + expectedErr: InvalidCertErr, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("InvalidCert%d", i), func(t *testing.T) { + err := certdb.ValidateCertificate(c.inputCert) + if err.Error() != c.expectedErr { + t.Errorf("Expected error not found:\nReceived: %s\n Expected: %s", err, c.expectedErr) + } + }) + } +} + +func TestCertificateMatchesCSRSuccess(t *testing.T) { + cases := []struct { + inputCSR string + inputCert string + }{ + { + inputCSR: ValidCSR2, + inputCert: ValidCert2, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("InvalidCert%d", i), func(t *testing.T) { + err := certdb.CertificateMatchesCSR(c.inputCert, c.inputCSR) + if err != nil { + t.Errorf("Certificate did not match when it should have") + } + }) + } +} + +func TestCertificateMatchesCSRFail(t *testing.T) { + var certificateDoesNotMatchErr = "certificate does not match CSR" + + cases := []struct { + inputCSR string + inputCert string + expectedErr string + }{ + { + inputCSR: ValidCSR1, + inputCert: ValidCert2, + expectedErr: certificateDoesNotMatchErr, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("InvalidCert%d", i), func(t *testing.T) { + err := certdb.CertificateMatchesCSR(c.inputCert, c.inputCSR) + if err.Error() != c.expectedErr { + t.Errorf("Expected error not found:\nReceived: %s\n Expected: %s", err, c.expectedErr) + } + }) + } +}