-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
584 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.db |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
package main | ||
|
||
func main() { | ||
// ... | ||
// ListenAndServe | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
module github.com/canonical/gocert | ||
|
||
go 1.22.1 | ||
|
||
require github.com/mattn/go-sqlite3 v1.14.22 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= | ||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
// Package certdb provides a simplistic ORM to communicate with an SQL database for storage | ||
package certdb | ||
|
||
import ( | ||
"database/sql" | ||
"fmt" | ||
|
||
_ "github.com/mattn/go-sqlite3" | ||
) | ||
|
||
const queryCreateTable = "CREATE TABLE IF NOT EXISTS %s (CSR VARCHAR PRIMARY KEY UNIQUE NOT NULL, Certificate VARCHAR DEFAULT '')" | ||
|
||
const queryGetAllCSRs = "SELECT rowid, * FROM %s" | ||
const queryGetCSR = "SELECT rowid, * FROM %s WHERE CSR=?" | ||
const queryCreateCSR = "INSERT INTO %s (CSR) VALUES (?)" | ||
const queryUpdateCSR = "UPDATE %s SET Certificate=? WHERE CSR=?" | ||
const queryDeleteCSR = "DELETE FROM %s WHERE CSR=?" | ||
|
||
// CertificateRequestRepository is the object used to communicate with the established repository. | ||
type CertificateRequestsRepository struct { | ||
table string | ||
conn *sql.DB | ||
} | ||
|
||
// A CertificateRequest struct represents an entry in the database. | ||
// The object contains a Certificate Request, its matching Certificate if any, and the row ID. | ||
type CertificateRequest struct { | ||
ID int | ||
CSR string | ||
Certificate string | ||
} | ||
|
||
// RetrieveAll gets every CertificateRequest entry in the table. | ||
func (db *CertificateRequestsRepository) RetrieveAll() ([]CertificateRequest, error) { | ||
rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.table)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var allCsrs []CertificateRequest | ||
defer rows.Close() | ||
for rows.Next() { | ||
var csr CertificateRequest | ||
if err := rows.Scan(&csr.ID, &csr.CSR, &csr.Certificate); err != nil { | ||
return nil, err | ||
} | ||
allCsrs = append(allCsrs, csr) | ||
} | ||
return allCsrs, nil | ||
} | ||
|
||
// Retrieve gets a given CSR from the repository. | ||
// It returns the row id and matching certificate alongside the CSR in a CertificateRequest object. | ||
func (db *CertificateRequestsRepository) Retrieve(csr string) (CertificateRequest, error) { | ||
var newCSR CertificateRequest | ||
row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.table), csr) | ||
if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil { | ||
return newCSR, err | ||
} | ||
return newCSR, nil | ||
} | ||
|
||
// Create creates a new entry in the repository. | ||
// The given CSR must be valid and unique | ||
func (db *CertificateRequestsRepository) Create(csr string) (int64, error) { | ||
if err := ValidateCertificateRequest(csr); err != nil { | ||
return 0, err | ||
} | ||
result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.table), csr) | ||
if err != nil { | ||
return 0, err | ||
} | ||
id, err := result.LastInsertId() | ||
if err != nil { | ||
return 0, err | ||
} | ||
return id, nil | ||
} | ||
|
||
// Update adds a new cert to the given CSR in the repository. | ||
// The given certificate must share the public key of the CSR and must be valid. | ||
func (db *CertificateRequestsRepository) Update(csr string, cert string) (int64, error) { | ||
if err := ValidateCertificate(cert); err != nil { | ||
return 0, err | ||
} | ||
if err := CertificateMatchesCSR(cert, csr); err != nil { | ||
return 0, err | ||
} | ||
result, err := db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.table), cert, csr) | ||
if err != nil { | ||
return 0, err | ||
} | ||
id, err := result.LastInsertId() | ||
if err != nil { | ||
return 0, err | ||
} | ||
return id, nil | ||
} | ||
|
||
// Delete removes a CSR from the database alongside the certificate that may have been generated for it. | ||
func (db *CertificateRequestsRepository) Delete(csr string) error { | ||
_, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.table), csr) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
// Close closes the connection to the repository cleanly. | ||
func (db *CertificateRequestsRepository) Close() error { | ||
if db.conn == nil { | ||
return nil | ||
} | ||
if err := db.conn.Close(); err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
// NewCertificateRequestsRepository connects to a given table in a given database, | ||
// stores the connection information and returns an object containing the information. | ||
// The database path must be a valid file path or ":memory:". | ||
// The table will be created if it doesn't exist in the format expected by the package. | ||
func NewCertificateRequestsRepository(databasePath string, tableName string) (*CertificateRequestsRepository, error) { | ||
conn, err := sql.Open("sqlite3", databasePath) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if _, err := conn.Exec(fmt.Sprintf(queryCreateTable, tableName)); err != nil { | ||
return nil, err | ||
} | ||
db := new(CertificateRequestsRepository) | ||
db.conn = conn | ||
db.table = tableName | ||
return db, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
package certdb_test | ||
|
||
import ( | ||
"log" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/canonical/gocert/internal/certdb" | ||
) | ||
|
||
func TestConnect(t *testing.T) { | ||
db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReqs") | ||
if err != nil { | ||
t.Fatalf("Can't connect to SQLite: %s", err) | ||
} | ||
db.Close() | ||
} | ||
|
||
func TestEndToEnd(t *testing.T) { | ||
db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck | ||
if err != nil { | ||
t.Fatalf("Couldn't complete NewCertificateRequestsRepository: %s", err) | ||
} | ||
defer db.Close() | ||
|
||
if _, err := db.Create(ValidCSR1); err != nil { | ||
t.Fatalf("Couldn't complete Create: %s", err) | ||
} | ||
if _, err := db.Create(ValidCSR2); err != nil { | ||
t.Fatalf("Couldn't complete Create: %s", err) | ||
} | ||
if _, err := db.Create(ValidCSR3); err != nil { | ||
t.Fatalf("Couldn't complete Create: %s", err) | ||
} | ||
|
||
res, err := db.RetrieveAll() | ||
if err != nil { | ||
t.Fatalf("Couldn't complete RetrieveAll: %s", err) | ||
} | ||
if len(res) != 3 { | ||
t.Fatalf("One or more CSRs weren't found in DB") | ||
} | ||
retrievedCSR, err := db.Retrieve(ValidCSR1) | ||
if err != nil { | ||
t.Fatalf("Couldn't complete Retrieve: %s", err) | ||
} | ||
if retrievedCSR.CSR != ValidCSR1 { | ||
t.Fatalf("The CSR from the database doesn't match the CSR that was given") | ||
} | ||
|
||
if err = db.Delete(ValidCSR1); err != nil { | ||
t.Fatalf("Couldn't complete Delete: %s", err) | ||
} | ||
res, _ = db.RetrieveAll() | ||
if len(res) != 2 { | ||
t.Fatalf("CSR's weren't deleted from the DB properly") | ||
} | ||
|
||
_, err = db.Update(ValidCSR2, ValidCert2) | ||
if err != nil { | ||
t.Fatalf("Couldn't complete Update: %s", err) | ||
} | ||
retrievedCSR, _ = db.Retrieve(ValidCSR2) | ||
if retrievedCSR.Certificate != ValidCert2 { | ||
t.Fatalf("The certificate that was uploaded does not match the certificate that was given: Retrieved: %s\nGiven: %s", retrievedCSR.Certificate, ValidCert2) | ||
} | ||
} | ||
|
||
func TestCreateFails(t *testing.T) { | ||
db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReqs") //nolint:errcheck | ||
defer db.Close() | ||
|
||
InvalidCSR := strings.ReplaceAll(ValidCSR1, "/", "+") | ||
if _, err := db.Create(InvalidCSR); err == nil { | ||
t.Fatalf("Expected error due to invalid CSR") | ||
} | ||
|
||
db.Create(ValidCSR1) //nolint:errcheck | ||
if _, err := db.Create(ValidCSR1); err == nil { | ||
t.Fatalf("Expected error due to duplicate CSR") | ||
} | ||
} | ||
|
||
func TestUpdateFails(t *testing.T) { | ||
db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck | ||
defer db.Close() | ||
|
||
db.Create(ValidCSR1) //nolint:errcheck | ||
db.Create(ValidCSR2) //nolint:errcheck | ||
InvalidCert := strings.ReplaceAll(ValidCert2, "/", "+") | ||
if _, err := db.Update(ValidCSR2, InvalidCert); err == nil { | ||
t.Fatalf("Expected updating with invalid cert to fail") | ||
} | ||
if _, err := db.Update(ValidCSR1, ValidCert2); err == nil { | ||
t.Fatalf("Expected updating with mismatched cert to fail") | ||
} | ||
} | ||
|
||
func TestRetrieve(t *testing.T) { | ||
db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck | ||
defer db.Close() | ||
|
||
db.Create(ValidCSR1) //nolint:errcheck | ||
if _, err := db.Retrieve(ValidCSR2); err == nil { | ||
t.Fatalf("Expected failure looking for nonexistent CSR") | ||
} | ||
|
||
} | ||
|
||
func Example() { | ||
db, err := certdb.NewCertificateRequestsRepository("./certs.db", "CertificateReq") | ||
if err != nil { | ||
log.Fatalln(err) | ||
} | ||
_, err = db.Create(ValidCSR2) | ||
if err != nil { | ||
log.Fatalln(err) | ||
} | ||
_, err = db.Update(ValidCSR2, ValidCert2) | ||
if err != nil { | ||
log.Fatalln(err) | ||
} | ||
entry, err := db.Retrieve(ValidCSR2) | ||
if err != nil { | ||
log.Fatalln(err) | ||
} | ||
if entry.Certificate != ValidCert2 { | ||
log.Fatalln("Retrieved Certificate doesn't match Stored Certificate") | ||
} | ||
err = db.Close() | ||
if err != nil { | ||
log.Fatalln(err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package certdb | ||
|
||
import ( | ||
"crypto/rsa" | ||
"crypto/x509" | ||
"encoding/pem" | ||
"errors" | ||
) | ||
|
||
// ValidateCertificateRequest validates the given CSR string to the following: | ||
// The string must be a valid PEM string, and should be of type CERTIFICATE REQUEST | ||
// The PEM string should be able to be parsed into a x509 Certificate Request | ||
func ValidateCertificateRequest(csr string) error { | ||
block, _ := pem.Decode([]byte(csr)) | ||
if block == nil { | ||
return errors.New("PEM Certificate Request string not found or malformed") | ||
} | ||
if block.Type != "CERTIFICATE REQUEST" { | ||
return errors.New("given PEM string not a certificate request") | ||
} | ||
_, err := x509.ParseCertificateRequest(block.Bytes) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
// ValidateCertificate validates the given Cert string to the following: | ||
// The cert string must be a valid PEM string, and should be of type CERTIFICATE | ||
// The PEM string should be able to be parsed into a x509 Certificate | ||
func ValidateCertificate(cert string) error { | ||
certBlock, _ := pem.Decode([]byte(cert)) | ||
if certBlock == nil { | ||
return errors.New("PEM Certificate string not found or malformed") | ||
} | ||
if certBlock.Type != "CERTIFICATE" { | ||
return errors.New("given PEM string not a certificate") | ||
} | ||
_, err := x509.ParseCertificate(certBlock.Bytes) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
// CertificateMatchesCSR makes sure that the given certificate and CSR match. | ||
// The given CSR and Cert must pass their respective validation functions | ||
// The given cert and CSR must share the same public key | ||
func CertificateMatchesCSR(cert string, csr string) error { | ||
if err := ValidateCertificateRequest(csr); err != nil { | ||
return err | ||
} | ||
if err := ValidateCertificate(cert); err != nil { | ||
return err | ||
} | ||
csrBlock, _ := pem.Decode([]byte(csr)) | ||
parsedCSR, _ := x509.ParseCertificateRequest(csrBlock.Bytes) | ||
certBlock, _ := pem.Decode([]byte(cert)) | ||
parsedCERT, _ := x509.ParseCertificate(certBlock.Bytes) | ||
certKey := parsedCERT.PublicKey.(*rsa.PublicKey) | ||
csrKey := parsedCSR.PublicKey.(*rsa.PublicKey) | ||
if !csrKey.Equal(certKey) { | ||
return errors.New("certificate does not match CSR") | ||
} | ||
return nil | ||
} |
Oops, something went wrong.