Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kayra1 committed Apr 1, 2024
1 parent de22f27 commit efdfbfd
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 32 deletions.
20 changes: 10 additions & 10 deletions internal/certdb/certdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,26 @@ import (
_ "github.com/mattn/go-sqlite3"
)

const queryCreateTable = "CREATE TABLE IF NOT EXISTS %s (CSR VARCHAR PRIMARY KEY UNIQUE NOT NULL, Certificate VARCHAR)"
const queryCreateTable = "CREATE TABLE IF NOT EXISTS %s (CSR VARCHAR PRIMARY KEY UNIQUE NOT NULL, Certificate VARCHAR DEFAULT '')"

const queryGetAllCSRs = "SELECT rowid, * FROM %s"
const queryGetCSR = "SELECT rowid, * FROM %s WHERE CSR=?"
const queryCreateCSR = "INSERT INTO %s (CSR) VALUES (?)"
const queryUpdateCSR = "UPDATE %s SET Certificate=? WHERE CSR=?"
const queryDeleteCSR = "DELETE FROM %s WHERE CSR=?"

type CertificateRequests struct {
type CertificateRequestsRepository struct {
table string
conn *sql.DB
}

type CertificateRequest struct {
ID int
CSR string
Certificate *string
Certificate string
}

func (db *CertificateRequests) RetrieveAll() ([]CertificateRequest, error) {
func (db *CertificateRequestsRepository) RetrieveAll() ([]CertificateRequest, error) {
rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.table))
if err != nil {
return nil, err
Expand All @@ -44,7 +44,7 @@ func (db *CertificateRequests) RetrieveAll() ([]CertificateRequest, error) {
return allCsrs, nil
}

func (db *CertificateRequests) Retrieve(csr string) (*CertificateRequest, error) {
func (db *CertificateRequestsRepository) Retrieve(csr string) (*CertificateRequest, error) {
var newCSR CertificateRequest
row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.table), csr)
if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil {
Expand All @@ -53,7 +53,7 @@ func (db *CertificateRequests) Retrieve(csr string) (*CertificateRequest, error)
return &newCSR, nil
}

func (db *CertificateRequests) Create(csr string) (int64, error) {
func (db *CertificateRequestsRepository) Create(csr string) (int64, error) {
if err := ValidateCertificateRequest(csr); err != nil {
return 0, err
}
Expand All @@ -68,7 +68,7 @@ func (db *CertificateRequests) Create(csr string) (int64, error) {
return id, nil
}

func (db *CertificateRequests) Update(csr string, cert string) (int64, error) {
func (db *CertificateRequestsRepository) Update(csr string, cert string) (int64, error) {
if err := ValidateCertificate(cert, csr); err != nil {
return 0, err
}
Expand All @@ -83,15 +83,15 @@ func (db *CertificateRequests) Update(csr string, cert string) (int64, error) {
return id, nil
}

func (db *CertificateRequests) Delete(csr string) error {
func (db *CertificateRequestsRepository) Delete(csr string) error {
_, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.table), csr)
if err != nil {
return err
}
return nil
}

func (db *CertificateRequests) Connect(databasePath string, tableName string) error {
func (db *CertificateRequestsRepository) Connect(databasePath string, tableName string) error {
conn, err := sql.Open("sqlite3", databasePath)
if err != nil {
return err
Expand All @@ -104,7 +104,7 @@ func (db *CertificateRequests) Connect(databasePath string, tableName string) er
return nil
}

func (db *CertificateRequests) Disconnect() error {
func (db *CertificateRequestsRepository) Close() error {
if db.conn == nil {
return nil
}
Expand Down
44 changes: 22 additions & 22 deletions internal/certdb/certdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ import (
)

func TestConnect(t *testing.T) {
db := new(certdb.CertificateRequests)
defer db.Disconnect()
db := new(certdb.CertificateRequestsRepository)
defer db.Close()
if err := db.Connect(":memory:", "CertificateReqs"); err != nil {
t.Fatalf("Can't connect to SQLite: %s", err)
}
}

func TestEndToEnd(t *testing.T) {
db := new(certdb.CertificateRequests)
defer db.Disconnect()
db.Connect(":memory:", "CertificateRequests")
db := new(certdb.CertificateRequestsRepository)
defer db.Close()
db.Connect(":memory:", "CertificateRequests") //nolint:errcheck

if _, err := db.Create(ValidCSR1); err != nil {
t.Fatalf("Couldn't complete Create: %s", err)
Expand Down Expand Up @@ -59,34 +59,34 @@ func TestEndToEnd(t *testing.T) {
t.Fatalf("Couldn't complete Update: %s", err)
}
retrievedCSR, _ = db.Retrieve(ValidCSR2)
if *retrievedCSR.Certificate != ValidCert2 {
t.Fatalf("The certificate that was uploaded does not match the certificate that was given: Retrieved: %s\nGiven: %s", *retrievedCSR.Certificate, ValidCert2)
if retrievedCSR.Certificate != ValidCert2 {
t.Fatalf("The certificate that was uploaded does not match the certificate that was given: Retrieved: %s\nGiven: %s", retrievedCSR.Certificate, ValidCert2)
}
}

func TestCreateFails(t *testing.T) {
db := new(certdb.CertificateRequests)
db.Connect(":memory:", "CertificateReqs")
defer db.Disconnect()
db := new(certdb.CertificateRequestsRepository)
db.Connect(":memory:", "CertificateReqs") //nolint:errcheck
defer db.Close()

InvalidCSR := strings.ReplaceAll(ValidCSR1, "/", "+")
if _, err := db.Create(InvalidCSR); err == nil {
t.Fatalf("Expected error due to invalid CSR")
}

db.Create(ValidCSR1)
db.Create(ValidCSR1) //nolint:errcheck
if _, err := db.Create(ValidCSR1); err == nil {
t.Fatalf("Expected error due to duplicate CSR")
}
}

func TestUpdateFails(t *testing.T) {
db := new(certdb.CertificateRequests)
defer db.Disconnect()
db.Connect(":memory:", "CertificateRequests")
db := new(certdb.CertificateRequestsRepository)
defer db.Close()
db.Connect(":memory:", "CertificateRequests") //nolint:errcheck

db.Create(ValidCSR1)
db.Create(ValidCSR2)
db.Create(ValidCSR1) //nolint:errcheck
db.Create(ValidCSR2) //nolint:errcheck
InvalidCert := strings.ReplaceAll(ValidCert2, "/", "+")
if _, err := db.Update(ValidCSR2, InvalidCert); err == nil {
t.Fatalf("Expected updating with invalid cert to fail")
Expand All @@ -97,21 +97,21 @@ func TestUpdateFails(t *testing.T) {
}

func TestRetrieve(t *testing.T) {
db := new(certdb.CertificateRequests)
defer db.Disconnect()
db.Connect(":memory:", "CertificateRequests")
db := new(certdb.CertificateRequestsRepository)
defer db.Close()
db.Connect(":memory:", "CertificateRequests") //nolint:errcheck

db.Create(ValidCSR1)
db.Create(ValidCSR1) //nolint:errcheck
if _, err := db.Retrieve(ValidCSR2); err == nil {
t.Fatalf("Expected failure looking for nonexistent CSR")
}

}

func Example() {
db := new(certdb.CertificateRequests)
db := new(certdb.CertificateRequestsRepository)
if err := db.Connect("./certs.db", "CertificateReq"); err != nil {
log.Fatalln(err)
}
defer db.Disconnect()
defer db.Close()
}

0 comments on commit efdfbfd

Please sign in to comment.