Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SQLite integration #5

Merged
merged 16 commits into from
Apr 3, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.db
2 changes: 1 addition & 1 deletion cmd/gocert/main.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package main

func main() {
// ...
// ListenAndServe
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/canonical/gocert

go 1.22.1

require github.com/mattn/go-sqlite3 v1.14.22
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
136 changes: 136 additions & 0 deletions internal/certdb/certdb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Package certdb provides a simplistic ORM to communicate with an SQL database for storage
package certdb
kayra1 marked this conversation as resolved.
Show resolved Hide resolved

import (
"database/sql"
"fmt"

_ "github.com/mattn/go-sqlite3"
)

const queryCreateTable = "CREATE TABLE IF NOT EXISTS %s (CSR VARCHAR PRIMARY KEY UNIQUE NOT NULL, Certificate VARCHAR DEFAULT '')"

const queryGetAllCSRs = "SELECT rowid, * FROM %s"
const queryGetCSR = "SELECT rowid, * FROM %s WHERE CSR=?"
const queryCreateCSR = "INSERT INTO %s (CSR) VALUES (?)"
const queryUpdateCSR = "UPDATE %s SET Certificate=? WHERE CSR=?"
const queryDeleteCSR = "DELETE FROM %s WHERE CSR=?"

// CertificateRequestRepository is the object used to communicate with the established repository.
type CertificateRequestsRepository struct {
table string
conn *sql.DB
}

// A CertificateRequest struct represents an entry in the database.
// The object contains a Certificate Request, its matching Certificate if any, and the row ID.
type CertificateRequest struct {
ID int
CSR string
Certificate string
}

// RetrieveAll gets every CertificateRequest entry in the table.
func (db *CertificateRequestsRepository) RetrieveAll() ([]CertificateRequest, error) {
rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.table))
if err != nil {
return nil, err
}

var allCsrs []CertificateRequest
defer rows.Close()
for rows.Next() {
var csr CertificateRequest
if err := rows.Scan(&csr.ID, &csr.CSR, &csr.Certificate); err != nil {
return nil, err
}
allCsrs = append(allCsrs, csr)
}
return allCsrs, nil
}

// Retrieve gets a given CSR from the repository.
// It returns the row id and matching certificate alongside the CSR in a CertificateRequest object.
func (db *CertificateRequestsRepository) Retrieve(csr string) (CertificateRequest, error) {
var newCSR CertificateRequest
row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.table), csr)
if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil {
newCSR.ID = -1
newCSR.Certificate = ""
newCSR.CSR = ""
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
return newCSR, err
}
return newCSR, nil
}

// Create creates a new entry in the repository.
// The given CSR must be valid and unique
func (db *CertificateRequestsRepository) Create(csr string) (int64, error) {
if err := ValidateCertificateRequest(csr); err != nil {
return 0, err
}
result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.table), csr)
if err != nil {
return 0, err
}
id, err := result.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
}

// Update adds a new cert to the given CSR in the repository.
// The given certificate must share the public key of the CSR and must be valid.
func (db *CertificateRequestsRepository) Update(csr string, cert string) (int64, error) {
if err := ValidateCertificate(cert, csr); err != nil {
return 0, err
}
result, err := db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.table), cert, csr)
if err != nil {
return 0, err
}
id, err := result.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
}

// Delete removes a CSR from the database alongside the certificate that may have been generated for it.
func (db *CertificateRequestsRepository) Delete(csr string) error {
_, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.table), csr)
if err != nil {
return err
}
return nil
}

// Close closes the connection to the repository cleanly.
func (db *CertificateRequestsRepository) Close() error {
if db.conn == nil {
return nil
}
if err := db.conn.Close(); err != nil {
return err
}
return nil
}

// NewCertificateRequestsRepository connects to a given table in a given database,
// stores the connection information and returns an object containing the information.
// The database path must be a valid file path or ":memory:".
// The table will be created if it doesn't exist in the format expected by the package.
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
func NewCertificateRequestsRepository(databasePath string, tableName string) (*CertificateRequestsRepository, error) {
conn, err := sql.Open("sqlite3", databasePath)
if err != nil {
return nil, err
}
if _, err := conn.Exec(fmt.Sprintf(queryCreateTable, tableName)); err != nil {
return nil, err
}
db := new(CertificateRequestsRepository)
db.conn = conn
db.table = tableName
return db, nil
}
134 changes: 134 additions & 0 deletions internal/certdb/certdb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package certdb_test

import (
"log"
"strings"
"testing"

"github.com/canonical/gocert/internal/certdb"
)

func TestConnect(t *testing.T) {
db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReqs")
if err != nil {
t.Fatalf("Can't connect to SQLite: %s", err)
ghislainbourgeois marked this conversation as resolved.
Show resolved Hide resolved
}
db.Close()
}

func TestEndToEnd(t *testing.T) {
db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck
if err != nil {
t.Fatalf("Couldn't complete NewCertificateRequestsRepository: %s", err)
}
defer db.Close()

if _, err := db.Create(ValidCSR1); err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
if _, err := db.Create(ValidCSR2); err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
if _, err := db.Create(ValidCSR3); err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}

res, err := db.RetrieveAll()
if err != nil {
t.Fatalf("Couldn't complete RetrieveAll: %s", err)
}
if len(res) != 3 {
t.Fatalf("One or more CSRs weren't found in DB")
}
retrievedCSR, err := db.Retrieve(ValidCSR1)
if err != nil {
t.Fatalf("Couldn't complete Retrieve: %s", err)
}
if retrievedCSR.CSR != ValidCSR1 {
t.Fatalf("The CSR from the database doesn't match the CSR that was given")
}

if err = db.Delete(ValidCSR1); err != nil {
t.Fatalf("Couldn't complete Delete: %s", err)
}
res, _ = db.RetrieveAll()
if len(res) != 2 {
t.Fatalf("CSR's weren't deleted from the DB properly")
}

_, err = db.Update(ValidCSR2, ValidCert2)
if err != nil {
t.Fatalf("Couldn't complete Update: %s", err)
}
retrievedCSR, _ = db.Retrieve(ValidCSR2)
if retrievedCSR.Certificate != ValidCert2 {
t.Fatalf("The certificate that was uploaded does not match the certificate that was given: Retrieved: %s\nGiven: %s", retrievedCSR.Certificate, ValidCert2)
}
}

func TestCreateFails(t *testing.T) {
db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReqs") //nolint:errcheck
defer db.Close()

InvalidCSR := strings.ReplaceAll(ValidCSR1, "/", "+")
if _, err := db.Create(InvalidCSR); err == nil {
t.Fatalf("Expected error due to invalid CSR")
}

db.Create(ValidCSR1) //nolint:errcheck
if _, err := db.Create(ValidCSR1); err == nil {
t.Fatalf("Expected error due to duplicate CSR")
}
}

func TestUpdateFails(t *testing.T) {
db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck
defer db.Close()

db.Create(ValidCSR1) //nolint:errcheck
db.Create(ValidCSR2) //nolint:errcheck
InvalidCert := strings.ReplaceAll(ValidCert2, "/", "+")
if _, err := db.Update(ValidCSR2, InvalidCert); err == nil {
t.Fatalf("Expected updating with invalid cert to fail")
}
if _, err := db.Update(ValidCSR1, ValidCert2); err == nil {
t.Fatalf("Expected updating with mismatched cert to fail")
}
}

func TestRetrieve(t *testing.T) {
db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck
defer db.Close()

db.Create(ValidCSR1) //nolint:errcheck
if _, err := db.Retrieve(ValidCSR2); err == nil {
t.Fatalf("Expected failure looking for nonexistent CSR")
}

}

func Example() {
db, err := certdb.NewCertificateRequestsRepository("./certs.db", "CertificateReq")
if err != nil {
log.Fatalln(err)
}
_, err = db.Create(ValidCSR2)
if err != nil {
log.Fatalln(err)
}
_, err = db.Update(ValidCSR2, ValidCert2)
if err != nil {
log.Fatalln(err)
}
entry, err := db.Retrieve(ValidCSR2)
if err != nil {
log.Fatalln(err)
}
if entry.Certificate != ValidCert2 {
log.Fatalln("Retrieved Certificate doesn't match Stored Certificate")
}
err = db.Close()
if err != nil {
log.Fatalln(err)
}
}
56 changes: 56 additions & 0 deletions internal/certdb/validation.go
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package certdb

import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
)

// ValidateCertificateRequest validates the given CSR string to the following:
// The string must be a valid PEM string, and should be of type CERTIFICATE REQUEST
// The PEM string should be able to be parsed into a x509 Certificate Request
func ValidateCertificateRequest(csrString string) error {
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
block, _ := pem.Decode([]byte(csrString))
if block == nil {
return errors.New("PEM Certificate Request string not found or malformed")
}
if block.Type != "CERTIFICATE REQUEST" {
return errors.New("given PEM string not a certificate request")
}
_, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return err
}
return nil
}

// ValidateCertificate validates the given Cert string to the following:
// The given CSR must pass the validation provided by ValidateCertificateRequest
// The cert string must be a valid PEM string, and should be of type CERTIFICATE
// The PEM string should be able to be parsed into a x509 Certificate
// The given cert and CSR must share the same public key
func ValidateCertificate(certString string, csrString string) error {
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
if err := ValidateCertificateRequest(csrString); err != nil {
return err
}
csrBlock, _ := pem.Decode([]byte(csrString))
csr, _ := x509.ParseCertificateRequest(csrBlock.Bytes)
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
certBlock, _ := pem.Decode([]byte(certString))
if certBlock == nil {
return errors.New("PEM Certificate string not found or malformed")
}
if certBlock.Type != "CERTIFICATE" {
return errors.New("given PEM string not a certificate")
}
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return err
}
certKey := cert.PublicKey.(*rsa.PublicKey)
csrKey := csr.PublicKey.(*rsa.PublicKey)
if !csrKey.Equal(certKey) {
return errors.New("certificate does not match CSR")
}
return nil
}
Loading
Loading