Skip to content

Commit

Permalink
new sql queries without pw hashing or tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kayra1 committed Jul 1, 2024
1 parent d26cc03 commit 13aaae2
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 71 deletions.
14 changes: 7 additions & 7 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func HealthCheck(w http.ResponseWriter, r *http.Request) {
// GetCertificateRequests returns all of the Certificate Requests
func GetCertificateRequests(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
certs, err := env.DB.RetrieveAll()
certs, err := env.DB.RetrieveAllCSRs()
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
Expand All @@ -97,7 +97,7 @@ func PostCertificateRequest(env *Environment) http.HandlerFunc {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
id, err := env.DB.Create(string(csr))
id, err := env.DB.CreateCSR(string(csr))
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
logErrorAndWriteResponse("given csr already recorded", http.StatusBadRequest, w)
Expand All @@ -122,7 +122,7 @@ func PostCertificateRequest(env *Environment) http.HandlerFunc {
func GetCertificateRequest(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
cert, err := env.DB.Retrieve(id)
cert, err := env.DB.RetrieveCSR(id)
if err != nil {
if err.Error() == "csr id not found" {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
Expand All @@ -147,7 +147,7 @@ func GetCertificateRequest(env *Environment) http.HandlerFunc {
func DeleteCertificateRequest(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
insertId, err := env.DB.Delete(id)
insertId, err := env.DB.DeleteCSR(id)
if err != nil {
if err.Error() == "csr id not found" {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
Expand All @@ -173,7 +173,7 @@ func PostCertificate(env *Environment) http.HandlerFunc {
return
}
id := r.PathValue("id")
insertId, err := env.DB.Update(id, string(cert))
insertId, err := env.DB.UpdateCSR(id, string(cert))
if err != nil {
if err.Error() == "csr id not found" ||
err.Error() == "certificate does not match CSR" ||
Expand Down Expand Up @@ -201,7 +201,7 @@ func PostCertificate(env *Environment) http.HandlerFunc {
func RejectCertificate(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
insertId, err := env.DB.Update(id, "rejected")
insertId, err := env.DB.UpdateCSR(id, "rejected")
if err != nil {
if err.Error() == "csr id not found" {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
Expand Down Expand Up @@ -229,7 +229,7 @@ func RejectCertificate(env *Environment) http.HandlerFunc {
func DeleteCertificate(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
insertId, err := env.DB.Update(id, "")
insertId, err := env.DB.UpdateCSR(id, "")
if err != nil {
if err.Error() == "csr id not found" {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
Expand Down
153 changes: 127 additions & 26 deletions internal/certdb/certdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,38 @@ import (
_ "github.com/mattn/go-sqlite3"
)

const queryCreateTable = "CREATE TABLE IF NOT EXISTS %s (CSR VARCHAR PRIMARY KEY UNIQUE NOT NULL, Certificate VARCHAR DEFAULT '')"
const queryCreateCSRsTable = `CREATE TABLE IF NOT EXISTS %s (
csr TEXT PRIMARY KEY UNIQUE NOT NULL,
certificate TEXT DEFAULT ''
)`

const queryGetAllCSRs = "SELECT rowid, * FROM %s"
const queryGetCSR = "SELECT rowid, * FROM %s WHERE rowid=?"
const queryCreateCSR = "INSERT INTO %s (CSR) VALUES (?)"
const queryUpdateCSR = "UPDATE %s SET Certificate=? WHERE rowid=?"
const queryDeleteCSR = "DELETE FROM %s WHERE rowid=?"
const (
queryGetAllCSRs = "SELECT rowid, * FROM %s"
queryGetCSR = "SELECT rowid, * FROM %s WHERE rowid=?"
queryCreateCSR = "INSERT INTO %s (csr) VALUES (?)"
queryUpdateCSR = "UPDATE %s SET certificate=? WHERE rowid=?"
queryDeleteCSR = "DELETE FROM %s WHERE rowid=?"
)

const queryCreateUsersTable = `CREATE TABLE IF NOT EXISTS users (

Check failure on line 25 in internal/certdb/certdb.go

View workflow job for this annotation

GitHub Actions / go-lint / lint

const `queryCreateUsersTable` is unused (unused)
user_id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password TEXT NOT NULL,
salt TEXT NOT NULL,
permissions INTEGER,
)`
const (
queryGetAllUsers = "SELECT * FROM users"
queryGetUser = "SELECT * FROM users WHERE user_id=?"
queryCreateUser = "INSERT INTO users (username, password, salt, permissions) VALUES (?, ?, ?, ?)"
queryUpdateUser = "UPDATE users SET password=?, salt=? WHERE user_id=?"
queryDeleteUser = "DELETE FROM users WHERE user_id=?"
)

// CertificateRequestRepository is the object used to communicate with the established repository.
type CertificateRequestsRepository struct {
table string
conn *sql.DB
certificateTable string
conn *sql.DB
}

// A CertificateRequest struct represents an entry in the database.
Expand All @@ -30,10 +50,17 @@ type CertificateRequest struct {
CSR string
Certificate string
}
type User struct {
ID int
Username string
Password string
Salt string
Permissions int
}

// RetrieveAll gets every CertificateRequest entry in the table.
func (db *CertificateRequestsRepository) RetrieveAll() ([]CertificateRequest, error) {
rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.table))
// RetrieveAllCSRs gets every CertificateRequest entry in the table.
func (db *CertificateRequestsRepository) RetrieveAllCSRs() ([]CertificateRequest, error) {
rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.certificateTable))
if err != nil {
return nil, err
}
Expand All @@ -50,11 +77,11 @@ func (db *CertificateRequestsRepository) RetrieveAll() ([]CertificateRequest, er
return allCsrs, nil
}

// Retrieve gets a given CSR from the repository.
// RetrieveCSR gets a given CSR from the repository.
// It returns the row id and matching certificate alongside the CSR in a CertificateRequest object.
func (db *CertificateRequestsRepository) Retrieve(id string) (CertificateRequest, error) {
func (db *CertificateRequestsRepository) RetrieveCSR(id string) (CertificateRequest, error) {
var newCSR CertificateRequest
row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.table), id)
row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.certificateTable), id)
if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil {
if err.Error() == "sql: no rows in result set" {
return newCSR, errors.New("csr id not found")
Expand All @@ -64,13 +91,13 @@ func (db *CertificateRequestsRepository) Retrieve(id string) (CertificateRequest
return newCSR, nil
}

// Create creates a new entry in the repository.
// CreateCSR creates a new entry in the repository.
// The given CSR must be valid and unique
func (db *CertificateRequestsRepository) Create(csr string) (int64, error) {
func (db *CertificateRequestsRepository) CreateCSR(csr string) (int64, error) {
if err := ValidateCertificateRequest(csr); err != nil {
return 0, errors.New("csr validation failed: " + err.Error())
}
result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.table), csr)
result, err := db.conn.Exec(fmt.Sprintf(queryCreateCSR, db.certificateTable), csr)
if err != nil {
return 0, err
}
Expand All @@ -81,10 +108,10 @@ func (db *CertificateRequestsRepository) Create(csr string) (int64, error) {
return id, nil
}

// Update adds a new cert to the given CSR in the repository.
// UpdateCSR adds a new cert to the given CSR in the repository.
// The given certificate must share the public key of the CSR and must be valid.
func (db *CertificateRequestsRepository) Update(id string, cert string) (int64, error) {
csr, err := db.Retrieve(id)
func (db *CertificateRequestsRepository) UpdateCSR(id string, cert string) (int64, error) {
csr, err := db.RetrieveCSR(id)
if err != nil {
return 0, err
}
Expand All @@ -98,7 +125,7 @@ func (db *CertificateRequestsRepository) Update(id string, cert string) (int64,
return 0, errors.New("cert validation failed: " + err.Error())
}
}
result, err := db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.table), cert, csr.ID)
result, err := db.conn.Exec(fmt.Sprintf(queryUpdateCSR, db.certificateTable), cert, csr.ID)
if err != nil {
return 0, err
}
Expand All @@ -109,9 +136,9 @@ func (db *CertificateRequestsRepository) Update(id string, cert string) (int64,
return insertId, nil
}

// Delete removes a CSR from the database alongside the certificate that may have been generated for it.
func (db *CertificateRequestsRepository) Delete(id string) (int64, error) {
result, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.table), id)
// DeleteCSR removes a CSR from the database alongside the certificate that may have been generated for it.
func (db *CertificateRequestsRepository) DeleteCSR(id string) (int64, error) {
result, err := db.conn.Exec(fmt.Sprintf(queryDeleteCSR, db.certificateTable), id)
if err != nil {
return 0, err
}
Expand All @@ -125,6 +152,80 @@ func (db *CertificateRequestsRepository) Delete(id string) (int64, error) {
return deleteId, nil
}

func (db *CertificateRequestsRepository) RetrieveAllUsers() ([]User, error) {
rows, err := db.conn.Query(queryGetAllUsers)
if err != nil {
return nil, err
}

var allUsers []User
defer rows.Close()
for rows.Next() {
var user User
if err := rows.Scan(&user.ID, &user.Username, &user.Password, &user.Salt, &user.Permissions); err != nil {
return nil, err
}
allUsers = append(allUsers, user)
}
return allUsers, nil
}
func (db *CertificateRequestsRepository) RetrieveUser(id string) (User, error) {
var newUser User
row := db.conn.QueryRow(queryGetUser, id)
if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Salt, &newUser.Permissions); err != nil {
if err.Error() == "sql: no rows in result set" {
return newUser, errors.New("user id not found")
}
return newUser, err
}
return newUser, nil
}

// func (db *CertificateRequestsRepository) CreateUser(username, password, permissions string) (int64, error) {
// passwordHash, salt, err := something bcrypt
// result, err := db.conn.Exec(queryCreateUser, username, passwordHash, salt, permissions)
// if err != nil {
// return 0, err
// }
// id, err := result.LastInsertId()
// if err != nil {
// return 0, err
// }
// return id, nil
// }

func (db *CertificateRequestsRepository) UpdateUser(id, password string) (int64, error) {
user, err := db.RetrieveUser(id)
if err != nil {
return 0, err
}
// passwordHash, salt := something bcrypt
result, err := db.conn.Exec(queryUpdateUser, user.ID)
if err != nil {
return 0, err
}
insertId, err := result.LastInsertId()
if err != nil {
return 0, err
}
return insertId, nil
}

func (db *CertificateRequestsRepository) DeleteUser(id string) (int64, error) {
result, err := db.conn.Exec(queryDeleteCSR, id)
if err != nil {
return 0, err
}
deleteId, err := result.RowsAffected()
if err != nil {
return 0, err
}
if deleteId == 0 {
return 0, errors.New("user id not found")
}
return deleteId, nil
}

// Close closes the connection to the repository cleanly.
func (db *CertificateRequestsRepository) Close() error {
if db.conn == nil {
Expand All @@ -145,11 +246,11 @@ func NewCertificateRequestsRepository(databasePath string, tableName string) (*C
if err != nil {
return nil, err
}
if _, err := conn.Exec(fmt.Sprintf(queryCreateTable, tableName)); err != nil {
if _, err := conn.Exec(fmt.Sprintf(queryCreateCSRsTable, tableName)); err != nil {
return nil, err
}
db := new(CertificateRequestsRepository)
db.conn = conn
db.table = tableName
db.certificateTable = tableName
return db, nil
}
Loading

0 comments on commit 13aaae2

Please sign in to comment.