diff --git a/internal/certdb/certdb.go b/internal/certdb/certdb.go index f163e6a..4c831eb 100644 --- a/internal/certdb/certdb.go +++ b/internal/certdb/certdb.go @@ -1,3 +1,4 @@ +// Package certdb provides a simplistic ORM to communicate with an SQL database for storage package certdb import ( @@ -15,6 +16,7 @@ const queryCreateCSR = "INSERT INTO %s (CSR) VALUES (?)" const queryUpdateCSR = "UPDATE %s SET Certificate=? WHERE CSR=?" const queryDeleteCSR = "DELETE FROM %s WHERE CSR=?" +// CertificateRequestRepository is the object used to communicate with the established repository. type CertificateRequestsRepository struct { table string conn *sql.DB @@ -91,19 +93,6 @@ func (db *CertificateRequestsRepository) Delete(csr string) error { return nil } -func (db *CertificateRequestsRepository) Connect(databasePath string, tableName string) error { - conn, err := sql.Open("sqlite3", databasePath) - if err != nil { - return err - } - db.table = tableName - db.conn = conn - if _, err := db.conn.Exec(fmt.Sprintf(queryCreateTable, db.table)); err != nil { - return err - } - return nil -} - func (db *CertificateRequestsRepository) Close() error { if db.conn == nil { return nil @@ -113,3 +102,17 @@ func (db *CertificateRequestsRepository) Close() error { } return nil } + +func NewCertificateRequestsRepository(databasePath string, tableName string) (*CertificateRequestsRepository, error) { + conn, err := sql.Open("sqlite3", databasePath) + if err != nil { + return nil, err + } + if _, err := conn.Exec(fmt.Sprintf(queryCreateTable, tableName)); err != nil { + return nil, err + } + db := new(CertificateRequestsRepository) + db.conn = conn + db.table = tableName + return db, nil +} diff --git a/internal/certdb/certdb_test.go b/internal/certdb/certdb_test.go index f937729..9cbd903 100644 --- a/internal/certdb/certdb_test.go +++ b/internal/certdb/certdb_test.go @@ -9,17 +9,19 @@ import ( ) func TestConnect(t *testing.T) { - db := new(certdb.CertificateRequestsRepository) - defer db.Close() - if err := db.Connect(":memory:", "CertificateReqs"); err != nil { + db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReqs") + if err != nil { t.Fatalf("Can't connect to SQLite: %s", err) } + db.Close() } func TestEndToEnd(t *testing.T) { - db := new(certdb.CertificateRequestsRepository) + db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck + if err != nil { + t.Fatalf("Couldn't complete NewCertificateRequestsRepository: %s", err) + } defer db.Close() - db.Connect(":memory:", "CertificateRequests") //nolint:errcheck if _, err := db.Create(ValidCSR1); err != nil { t.Fatalf("Couldn't complete Create: %s", err) @@ -65,8 +67,7 @@ func TestEndToEnd(t *testing.T) { } func TestCreateFails(t *testing.T) { - db := new(certdb.CertificateRequestsRepository) - db.Connect(":memory:", "CertificateReqs") //nolint:errcheck + db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReqs") //nolint:errcheck defer db.Close() InvalidCSR := strings.ReplaceAll(ValidCSR1, "/", "+") @@ -81,9 +82,8 @@ func TestCreateFails(t *testing.T) { } func TestUpdateFails(t *testing.T) { - db := new(certdb.CertificateRequestsRepository) + db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck defer db.Close() - db.Connect(":memory:", "CertificateRequests") //nolint:errcheck db.Create(ValidCSR1) //nolint:errcheck db.Create(ValidCSR2) //nolint:errcheck @@ -97,9 +97,8 @@ func TestUpdateFails(t *testing.T) { } func TestRetrieve(t *testing.T) { - db := new(certdb.CertificateRequestsRepository) + db, _ := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") //nolint:errcheck defer db.Close() - db.Connect(":memory:", "CertificateRequests") //nolint:errcheck db.Create(ValidCSR1) //nolint:errcheck if _, err := db.Retrieve(ValidCSR2); err == nil { @@ -109,9 +108,24 @@ func TestRetrieve(t *testing.T) { } func Example() { - db := new(certdb.CertificateRequestsRepository) - if err := db.Connect("./certs.db", "CertificateReq"); err != nil { + db, err := certdb.NewCertificateRequestsRepository("./certs.db", "CertificateReq") + if err != nil { + log.Fatalln(err) + } + _, err = db.Create(ValidCSR2) + if err != nil { log.Fatalln(err) } + _, err = db.Update(ValidCSR2, ValidCert2) + if err != nil { + log.Fatalln(err) + } + entry, err := db.Retrieve(ValidCSR2) + if err != nil { + log.Fatalln(err) + } + if entry.Certificate != ValidCert2 { + log.Fatalln("Retrieved Certificate doesn't match Stored Certificate") + } defer db.Close() }