Skip to content

Commit

Permalink
feat: SQLite integration (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
kayra1 authored Apr 3, 2024
1 parent 9043782 commit 435f407
Show file tree
Hide file tree
Showing 8 changed files with 584 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.db
2 changes: 1 addition & 1 deletion cmd/gocert/main.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package main

func main() {
// ...
// ListenAndServe
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/canonical/gocert

go 1.22.1

require github.com/mattn/go-sqlite3 v1.14.22
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
136 changes: 136 additions & 0 deletions internal/certdb/certdb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Package certdb provides a simplistic ORM to communicate with an SQL database for storage
package certdb

import (
"database/sql"
"fmt"

_ "github.com/mattn/go-sqlite3"
)

const queryCreateTable = "CREATE TABLE IF NOT EXISTS %s (CSR VARCHAR PRIMARY KEY UNIQUE NOT NULL, Certificate VARCHAR DEFAULT '')"

const queryGetAllCSRs = "SELECT rowid, * FROM %s"
const queryGetCSR = "SELECT rowid, * FROM %s WHERE CSR=?"
const queryCreateCSR = "INSERT INTO %s (CSR) VALUES (?)"
const queryUpdateCSR = "UPDATE %s SET Certificate=? WHERE CSR=?"
const queryDeleteCSR = "DELETE FROM %s WHERE CSR=?"

// CertificateRequestRepository is the object used to communicate with the established repository.
type CertificateRequestsRepository struct {
table string
conn *sql.DB
}

// A CertificateRequest struct represents an entry in the database.
// The object contains a Certificate Request, its matching Certificate if any, and the row ID.
type CertificateRequest struct {
ID int
CSR string
Certificate string
}

// RetrieveAll gets every CertificateRequest entry in the table.
func (db *CertificateRequestsRepository) RetrieveAll() ([]CertificateRequest, error) {
rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.table))
if err != nil {
return nil, err
}

var allCsrs []CertificateRequest
defer rows.Close()
for rows.Next() {
var csr CertificateRequest
if err := rows.Scan(&csr.ID, &csr.CSR, &csr.Certificate); err != nil {
return nil, err
}
allCsrs = append(allCsrs, csr)
}
return allCsrs, nil
}

// Retrieve gets a given CSR from the repository.
// It returns the row id and matching certificate alongside the CSR in a CertificateRequest object.
func (db *CertificateRequestsRepository) Retrieve(csr string) (CertificateRequest, error) {
var newCSR CertificateRequest
row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.table), csr)
if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil {
return newCSR, err
}
return newCSR, nil
}

// Create creates a new entry in the repository.
// The given CSR must be valid and unique
func (db *CertificateRequestsRepository) Create(csr string) (int64, error) {
if err := ValidateCertificateRequest(csr); err != nil {
return 0, err
}
result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.table), csr)
if err != nil {
return 0, err
}
id, err := result.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
}

// Update adds a new cert to the given CSR in the repository.
// The given certificate must share the public key of the CSR and must be valid.
func (db *CertificateRequestsRepository) Update(csr string, cert string) (int64, error) {
if err := ValidateCertificate(cert); err != nil {
return 0, err
}
if err := CertificateMatchesCSR(cert, csr); err != nil {
return 0, err
}
result, err := db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.table), cert, csr)
if err != nil {
return 0, err
}
id, err := result.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
}

// Delete removes a CSR from the database alongside the certificate that may have been generated for it.
func (db *CertificateRequestsRepository) Delete(csr string) error {
_, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.table), csr)
if err != nil {
return err
}
return nil
}

// Close closes the connection to the repository cleanly.
func (db *CertificateRequestsRepository) Close() error {
if db.conn == nil {
return nil
}
if err := db.conn.Close(); err != nil {
return err
}
return nil
}

// NewCertificateRequestsRepository connects to a given table in a given database,
// stores the connection information and returns an object containing the information.
// The database path must be a valid file path or ":memory:".
// The table will be created if it doesn't exist in the format expected by the package.
func NewCertificateRequestsRepository(databasePath string, tableName string) (*CertificateRequestsRepository, error) {
conn, err := sql.Open("sqlite3", databasePath)
if err != nil {
return nil, err
}
if _, err := conn.Exec(fmt.Sprintf(queryCreateTable, tableName)); err != nil {
return nil, err
}
db := new(CertificateRequestsRepository)
db.conn = conn
db.table = tableName
return db, nil
}
134 changes: 134 additions & 0 deletions internal/certdb/certdb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package certdb_test

import (
"log"
"strings"
"testing"

"github.com/canonical/gocert/internal/certdb"
)

func TestConnect(t *testing.T) {
db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReqs")
if err != nil {
t.Fatalf("Can't connect to SQLite: %s", err)
}
db.Close()
}

func TestEndToEnd(t *testing.T) {
db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck
if err != nil {
t.Fatalf("Couldn't complete NewCertificateRequestsRepository: %s", err)
}
defer db.Close()

if _, err := db.Create(ValidCSR1); err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
if _, err := db.Create(ValidCSR2); err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
if _, err := db.Create(ValidCSR3); err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}

res, err := db.RetrieveAll()
if err != nil {
t.Fatalf("Couldn't complete RetrieveAll: %s", err)
}
if len(res) != 3 {
t.Fatalf("One or more CSRs weren't found in DB")
}
retrievedCSR, err := db.Retrieve(ValidCSR1)
if err != nil {
t.Fatalf("Couldn't complete Retrieve: %s", err)
}
if retrievedCSR.CSR != ValidCSR1 {
t.Fatalf("The CSR from the database doesn't match the CSR that was given")
}

if err = db.Delete(ValidCSR1); err != nil {
t.Fatalf("Couldn't complete Delete: %s", err)
}
res, _ = db.RetrieveAll()
if len(res) != 2 {
t.Fatalf("CSR's weren't deleted from the DB properly")
}

_, err = db.Update(ValidCSR2, ValidCert2)
if err != nil {
t.Fatalf("Couldn't complete Update: %s", err)
}
retrievedCSR, _ = db.Retrieve(ValidCSR2)
if retrievedCSR.Certificate != ValidCert2 {
t.Fatalf("The certificate that was uploaded does not match the certificate that was given: Retrieved: %s\nGiven: %s", retrievedCSR.Certificate, ValidCert2)
}
}

func TestCreateFails(t *testing.T) {
db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReqs") //nolint:errcheck
defer db.Close()

InvalidCSR := strings.ReplaceAll(ValidCSR1, "/", "+")
if _, err := db.Create(InvalidCSR); err == nil {
t.Fatalf("Expected error due to invalid CSR")
}

db.Create(ValidCSR1) //nolint:errcheck
if _, err := db.Create(ValidCSR1); err == nil {
t.Fatalf("Expected error due to duplicate CSR")
}
}

func TestUpdateFails(t *testing.T) {
db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck
defer db.Close()

db.Create(ValidCSR1) //nolint:errcheck
db.Create(ValidCSR2) //nolint:errcheck
InvalidCert := strings.ReplaceAll(ValidCert2, "/", "+")
if _, err := db.Update(ValidCSR2, InvalidCert); err == nil {
t.Fatalf("Expected updating with invalid cert to fail")
}
if _, err := db.Update(ValidCSR1, ValidCert2); err == nil {
t.Fatalf("Expected updating with mismatched cert to fail")
}
}

func TestRetrieve(t *testing.T) {
db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck
defer db.Close()

db.Create(ValidCSR1) //nolint:errcheck
if _, err := db.Retrieve(ValidCSR2); err == nil {
t.Fatalf("Expected failure looking for nonexistent CSR")
}

}

func Example() {
db, err := certdb.NewCertificateRequestsRepository("./certs.db", "CertificateReq")
if err != nil {
log.Fatalln(err)
}
_, err = db.Create(ValidCSR2)
if err != nil {
log.Fatalln(err)
}
_, err = db.Update(ValidCSR2, ValidCert2)
if err != nil {
log.Fatalln(err)
}
entry, err := db.Retrieve(ValidCSR2)
if err != nil {
log.Fatalln(err)
}
if entry.Certificate != ValidCert2 {
log.Fatalln("Retrieved Certificate doesn't match Stored Certificate")
}
err = db.Close()
if err != nil {
log.Fatalln(err)
}
}
66 changes: 66 additions & 0 deletions internal/certdb/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package certdb

import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
)

// ValidateCertificateRequest validates the given CSR string to the following:
// The string must be a valid PEM string, and should be of type CERTIFICATE REQUEST
// The PEM string should be able to be parsed into a x509 Certificate Request
func ValidateCertificateRequest(csr string) error {
block, _ := pem.Decode([]byte(csr))
if block == nil {
return errors.New("PEM Certificate Request string not found or malformed")
}
if block.Type != "CERTIFICATE REQUEST" {
return errors.New("given PEM string not a certificate request")
}
_, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return err
}
return nil
}

// ValidateCertificate validates the given Cert string to the following:
// The cert string must be a valid PEM string, and should be of type CERTIFICATE
// The PEM string should be able to be parsed into a x509 Certificate
func ValidateCertificate(cert string) error {
certBlock, _ := pem.Decode([]byte(cert))
if certBlock == nil {
return errors.New("PEM Certificate string not found or malformed")
}
if certBlock.Type != "CERTIFICATE" {
return errors.New("given PEM string not a certificate")
}
_, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return err
}
return nil
}

// CertificateMatchesCSR makes sure that the given certificate and CSR match.
// The given CSR and Cert must pass their respective validation functions
// The given cert and CSR must share the same public key
func CertificateMatchesCSR(cert string, csr string) error {
if err := ValidateCertificateRequest(csr); err != nil {
return err
}
if err := ValidateCertificate(cert); err != nil {
return err
}
csrBlock, _ := pem.Decode([]byte(csr))
parsedCSR, _ := x509.ParseCertificateRequest(csrBlock.Bytes)
certBlock, _ := pem.Decode([]byte(cert))
parsedCERT, _ := x509.ParseCertificate(certBlock.Bytes)
certKey := parsedCERT.PublicKey.(*rsa.PublicKey)
csrKey := parsedCSR.PublicKey.(*rsa.PublicKey)
if !csrKey.Equal(certKey) {
return errors.New("certificate does not match CSR")
}
return nil
}
Loading

0 comments on commit 435f407

Please sign in to comment.