Skip to content

Commit

Permalink
feat: GoCert 1.0 Handlers (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
kayra1 committed Apr 24, 2024
1 parent bd0a096 commit f82ae0a
Show file tree
Hide file tree
Showing 9 changed files with 640 additions and 96 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
*.db
*.db
*.pem
*config.yaml

.DS_Store
13 changes: 2 additions & 11 deletions cmd/gocert/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ import (
"log"
"os"

server "github.com/canonical/gocert/api"
"github.com/canonical/gocert/internal/certdb"
server "github.com/canonical/gocert/internal/api"
)

func main() {
Expand All @@ -17,15 +16,7 @@ func main() {
if *configFilePtr == "" {
log.Fatalf("Providing a valid config file is required.")
}
config, err := server.ValidateConfigFile(*configFilePtr)
if err != nil {
log.Fatalf("Config file validation failed: %s.", err)
}
_, err = certdb.NewCertificateRequestsRepository(config.DBPath, "CertificateRequests")
if err != nil {
log.Fatalf("Couldn't connect to database: %s", err)
}
srv, err := server.NewServer(config.Cert, config.Key, config.Port)
srv, err := server.NewServer(*configFilePtr)
if err != nil {
log.Fatalf("Couldn't create server: %s", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/gocert/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func TestGoCertFail(t *testing.T) {
ExpectedOutput string
}{
{"flags not set", []string{}, validConfig, "Providing a valid config file is required."},
{"config file not valid", []string{"-config", "config.yaml"}, invalidConfig, "Config file validation failed:"},
{"config file not valid", []string{"-config", "config.yaml"}, invalidConfig, "config file validation failed:"},
{"database not connectable", []string{"-config", "config.yaml"}, invalidDBConfig, "Couldn't connect to database:"},
}
for _, tc := range cases {
Expand Down
194 changes: 194 additions & 0 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package server

import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
)

// NewGoCertRouter takes in an environment struct, passes it along to any handlers that will need
// access to it, then builds and returns it for a server to consume
func NewGoCertRouter(env *Environment) http.Handler {
router := http.NewServeMux()
router.HandleFunc("GET /certificate_requests", GetCertificateRequests(env))
router.HandleFunc("POST /certificate_requests", PostCertificateRequest(env))
router.HandleFunc("GET /certificate_requests/{id}", GetCertificateRequest(env))
router.HandleFunc("DELETE /certificate_requests/{id}", DeleteCertificateRequest(env))
router.HandleFunc("POST /certificate_requests/{id}/certificate", PostCertificate(env))
router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env))

v1 := http.NewServeMux()
v1.HandleFunc("GET /status", HealthCheck)
v1.Handle("/api/v1/", http.StripPrefix("/api/v1", router))

return logging(v1)
}

// the health check endpoint simply returns a http.StatusOK
func HealthCheck(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) //nolint:errcheck
}

// GetCertificateRequests returns all of the Certificate Requests
func GetCertificateRequests(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
certs, err := env.DB.RetrieveAll()
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
body, err := json.Marshal(certs)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
if _, err := w.Write(body); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
}

// PostCertificateRequest creates a new Certificate Request, and returns the id of the created row
func PostCertificateRequest(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
csr, err := io.ReadAll(r.Body)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
id, err := env.DB.Create(string(csr))
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
logErrorAndWriteResponse("given csr already recorded", http.StatusBadRequest, w)
return
}
if strings.Contains(err.Error(), "csr validation failed") {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
return
}
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
w.WriteHeader(http.StatusCreated)
if _, err := w.Write([]byte(strconv.FormatInt(id, 10))); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
}

// GetCertificateRequests receives an id as a path parameter, and
// returns the corresponding Certificate Request
func GetCertificateRequest(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
cert, err := env.DB.Retrieve(id)
if err != nil {
if err.Error() == "csr id not found" {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
return
}
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
body, err := json.Marshal(cert)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
if _, err := w.Write(body); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
}

// DeleteCertificateRequest handler receives an id as a path parameter,
// deletes the corresponding Certificate Request, and returns a http.StatusNoContent on success
func DeleteCertificateRequest(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
insertId, err := env.DB.Delete(id)
if err != nil {
if err.Error() == "csr id not found" {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
return
}
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
w.WriteHeader(http.StatusAccepted)
if _, err := w.Write([]byte(strconv.FormatInt(insertId, 10))); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
}

// PostCertificate handler receives an id as a path parameter,
// and attempts to add a given certificate to the corresponding certificate request
func PostCertificate(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
cert, err := io.ReadAll(r.Body)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
return
}
id := r.PathValue("id")
insertId, err := env.DB.Update(id, string(cert))
if err != nil {
if err.Error() == "csr id not found" ||
err.Error() == "certificate does not match CSR" ||
strings.Contains(err.Error(), "cert validation failed") {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
return
}
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
w.WriteHeader(http.StatusCreated)
if _, err := w.Write([]byte(strconv.FormatInt(insertId, 10))); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
}

// DeleteCertificate handler receives an id as a path parameter,
// and attempts to add a given certificate to the corresponding certificate request
func DeleteCertificate(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
insertId, err := env.DB.Update(id, "")
if err != nil {
if err.Error() == "csr id not found" {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
return
}
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
w.WriteHeader(http.StatusAccepted)
if _, err := w.Write([]byte(strconv.FormatInt(insertId, 10))); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
}

// The logging middleware captures any http request coming through, and logs it
func logging(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
log.Println(r.Method, r.URL.Path)
})
}

// logErrorAndWriteResponse is a helper function that logs any error and writes it back as an http response
func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) {
errMsg := fmt.Sprintf("error: %s", msg)
log.Println(errMsg)
w.WriteHeader(status)
if _, err := w.Write([]byte(errMsg)); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
Loading

0 comments on commit f82ae0a

Please sign in to comment.