Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SQLite integration #5

Merged
merged 16 commits into from
Apr 3, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.db
2 changes: 1 addition & 1 deletion cmd/gocert/main.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package main

func main() {
// ...
// ListenAndServe
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/canonical/gocert

go 1.22.1

require github.com/mattn/go-sqlite3 v1.14.22
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
115 changes: 115 additions & 0 deletions internal/certdb/certdb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package certdb
kayra1 marked this conversation as resolved.
Show resolved Hide resolved

import (
"database/sql"
"fmt"

_ "github.com/mattn/go-sqlite3"
)

const queryCreateTable = "CREATE TABLE IF NOT EXISTS %s (CSR VARCHAR PRIMARY KEY UNIQUE NOT NULL, Certificate VARCHAR)"

const queryGetAllCSRs = "SELECT rowid, * FROM %s"
const queryGetCSR = "SELECT rowid, * FROM %s WHERE CSR=?"
const queryCreateCSR = "INSERT INTO %s (CSR) VALUES (?)"
const queryUpdateCSR = "UPDATE %s SET Certificate=? WHERE CSR=?"
const queryDeleteCSR = "DELETE FROM %s WHERE CSR=?"

type CertificateRequests struct {
table string
conn *sql.DB
}
kayra1 marked this conversation as resolved.
Show resolved Hide resolved

type CertificateRequest struct {
ID int
CSR string
Certificate *string
ghislainbourgeois marked this conversation as resolved.
Show resolved Hide resolved
}

func (db *CertificateRequests) RetrieveAll() ([]CertificateRequest, error) {
rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.table))
if err != nil {
return nil, err
}

var allCsrs []CertificateRequest
defer rows.Close()
for rows.Next() {
var csr CertificateRequest
if err := rows.Scan(&csr.ID, &csr.CSR, &csr.Certificate); err != nil {
return nil, err
}
allCsrs = append(allCsrs, csr)
}
return allCsrs, nil
}

func (db *CertificateRequests) Retrieve(csr string) (*CertificateRequest, error) {
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
var newCSR CertificateRequest
row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.table), csr)
if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil {
return nil, err
}
return &newCSR, nil
}

func (db *CertificateRequests) Create(csr string) (int64, error) {
if err := ValidateCertificateRequest(csr); err != nil {
return 0, err
}
result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.table), csr)
if err != nil {
return 0, err
}
id, err := result.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
}

func (db *CertificateRequests) Update(csr string, cert string) (int64, error) {
if err := ValidateCertificate(cert, csr); err != nil {
return 0, err
}
result, err := db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.table), cert, csr)
if err != nil {
return 0, err
}
id, err := result.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
}

func (db *CertificateRequests) Delete(csr string) error {
_, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.table), csr)
if err != nil {
return err
}
return nil
}

func (db *CertificateRequests) Connect(databasePath string, tableName string) error {
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
conn, err := sql.Open("sqlite3", databasePath)
if err != nil {
return err
}
db.table = tableName
db.conn = conn
if _, err := db.conn.Exec(fmt.Sprintf(queryCreateTable, db.table)); err != nil {
return err
}
return nil
}

func (db *CertificateRequests) Disconnect() error {
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
if db.conn == nil {
return nil
}
if err := db.conn.Close(); err != nil {
return err
}
return nil
}
117 changes: 117 additions & 0 deletions internal/certdb/certdb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package certdb_test

import (
"log"
"strings"
"testing"

"github.com/canonical/gocert/internal/certdb"
)

func TestConnect(t *testing.T) {
db := new(certdb.CertificateRequests)
defer db.Disconnect()
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
if err := db.Connect(":memory:", "CertificateReqs"); err != nil {
t.Fatalf("Can't connect to SQLite: %s", err)
ghislainbourgeois marked this conversation as resolved.
Show resolved Hide resolved
}
}

func TestEndToEnd(t *testing.T) {
db := new(certdb.CertificateRequests)
defer db.Disconnect()
db.Connect(":memory:", "CertificateRequests")

if _, err := db.Create(ValidCSR1); err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
if _, err := db.Create(ValidCSR2); err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}
if _, err := db.Create(ValidCSR3); err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
}

res, err := db.RetrieveAll()
if err != nil {
t.Fatalf("Couldn't complete RetrieveAll: %s", err)
}
if len(res) != 3 {
t.Fatalf("One or more CSRs weren't found in DB")
}
retrievedCSR, err := db.Retrieve(ValidCSR1)
if err != nil {
t.Fatalf("Couldn't complete Retrieve: %s", err)
}
if retrievedCSR.CSR != ValidCSR1 {
t.Fatalf("The CSR from the database doesn't match the CSR that was given")
}

if err = db.Delete(ValidCSR1); err != nil {
t.Fatalf("Couldn't complete Delete: %s", err)
}
res, _ = db.RetrieveAll()
if len(res) != 2 {
t.Fatalf("CSR's weren't deleted from the DB properly")
}

_, err = db.Update(ValidCSR2, ValidCert2)
if err != nil {
t.Fatalf("Couldn't complete Update: %s", err)
}
retrievedCSR, _ = db.Retrieve(ValidCSR2)
if *retrievedCSR.Certificate != ValidCert2 {
t.Fatalf("The certificate that was uploaded does not match the certificate that was given: Retrieved: %s\nGiven: %s", *retrievedCSR.Certificate, ValidCert2)
}
}

func TestCreateFails(t *testing.T) {
db := new(certdb.CertificateRequests)
db.Connect(":memory:", "CertificateReqs")
defer db.Disconnect()

InvalidCSR := strings.ReplaceAll(ValidCSR1, "/", "+")
if _, err := db.Create(InvalidCSR); err == nil {
t.Fatalf("Expected error due to invalid CSR")
}

db.Create(ValidCSR1)
if _, err := db.Create(ValidCSR1); err == nil {
t.Fatalf("Expected error due to duplicate CSR")
}
}

func TestUpdateFails(t *testing.T) {
db := new(certdb.CertificateRequests)
defer db.Disconnect()
db.Connect(":memory:", "CertificateRequests")

db.Create(ValidCSR1)
db.Create(ValidCSR2)
InvalidCert := strings.ReplaceAll(ValidCert2, "/", "+")
if _, err := db.Update(ValidCSR2, InvalidCert); err == nil {
t.Fatalf("Expected updating with invalid cert to fail")
}
if _, err := db.Update(ValidCSR1, ValidCert2); err == nil {
t.Fatalf("Expected updating with mismatched cert to fail")
}
}

func TestRetrieve(t *testing.T) {
db := new(certdb.CertificateRequests)
defer db.Disconnect()
db.Connect(":memory:", "CertificateRequests")

db.Create(ValidCSR1)
if _, err := db.Retrieve(ValidCSR2); err == nil {
t.Fatalf("Expected failure looking for nonexistent CSR")
}

}

func Example() {
db := new(certdb.CertificateRequests)
if err := db.Connect("./certs.db", "CertificateReq"); err != nil {
log.Fatalln(err)
}
defer db.Disconnect()
}
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
48 changes: 48 additions & 0 deletions internal/certdb/validation.go
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package certdb

import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
)

func ValidateCertificateRequest(csrString string) error {
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
block, _ := pem.Decode([]byte(csrString))
if block == nil {
return errors.New("PEM Certificate Request string not found or malformed")
}
if block.Type != "CERTIFICATE REQUEST" {
return errors.New("given PEM string not a certificate request")
}
_, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return err
}
return nil
}

func ValidateCertificate(certString string, csrString string) error {
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
if err := ValidateCertificateRequest(csrString); err != nil {
return err
}
csrBlock, _ := pem.Decode([]byte(csrString))
csr, _ := x509.ParseCertificateRequest(csrBlock.Bytes)
gruyaume marked this conversation as resolved.
Show resolved Hide resolved
certBlock, _ := pem.Decode([]byte(certString))
if certBlock == nil {
return errors.New("PEM Certificate string not found or malformed")
}
if certBlock.Type != "CERTIFICATE" {
return errors.New("given PEM string not a certificate")
}
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return err
}
certKey := cert.PublicKey.(*rsa.PublicKey)
csrKey := csr.PublicKey.(*rsa.PublicKey)
if !csrKey.Equal(certKey) {
return errors.New("certificate does not match CSR")
}
return nil
}
Loading
Loading