From 7057751bf3b3fa498c330ccd592dc02669641d81 Mon Sep 17 00:00:00 2001 From: saltiyazan Date: Thu, 4 Jul 2024 16:47:41 +0200 Subject: [PATCH] feat: Handlers to create and fetch accounts (#37) --- internal/api/handlers.go | 137 ++++++++++++++++++++++++++++++++-- internal/api/handlers_test.go | 122 ++++++++++++++++++++++++++++-- internal/certdb/certdb.go | 12 +-- 3 files changed, 253 insertions(+), 18 deletions(-) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 07b68e2..12a284f 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -1,15 +1,19 @@ package server import ( + "crypto/rand" "encoding/json" + "errors" "fmt" "io" "io/fs" "log" + "math/big" "net/http" "strconv" "strings" + "github.com/canonical/gocert/internal/certdb" metrics "github.com/canonical/gocert/internal/metrics" "github.com/canonical/gocert/ui" ) @@ -27,6 +31,10 @@ func NewGoCertRouter(env *Environment) http.Handler { apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/reject", RejectCertificate(env)) apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env)) + apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env)) + apiV1Router.HandleFunc("GET /accounts", GetUserAccounts(env)) + apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env)) + m := metrics.NewMetricsSubsystem(env.DB) frontendHandler := newFrontendFileServer() @@ -124,8 +132,8 @@ func GetCertificateRequest(env *Environment) http.HandlerFunc { id := r.PathValue("id") cert, err := env.DB.RetrieveCSR(id) if err != nil { - if err.Error() == "csr id not found" { - logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w) + if errors.Is(err, certdb.ErrIdNotFound) { + logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w) return } logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) @@ -149,8 +157,8 @@ func DeleteCertificateRequest(env *Environment) http.HandlerFunc { id := r.PathValue("id") insertId, err := env.DB.DeleteCSR(id) if err != nil { - if err.Error() == "csr id not found" { - logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w) + if errors.Is(err, certdb.ErrIdNotFound) { + logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w) return } logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) @@ -175,7 +183,7 @@ func PostCertificate(env *Environment) http.HandlerFunc { id := r.PathValue("id") insertId, err := env.DB.UpdateCSR(id, string(cert)) if err != nil { - if err.Error() == "csr id not found" || + if errors.Is(err, certdb.ErrIdNotFound) || err.Error() == "certificate does not match CSR" || strings.Contains(err.Error(), "cert validation failed") { logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w) @@ -203,8 +211,8 @@ func RejectCertificate(env *Environment) http.HandlerFunc { id := r.PathValue("id") insertId, err := env.DB.UpdateCSR(id, "rejected") if err != nil { - if err.Error() == "csr id not found" { - logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w) + if errors.Is(err, certdb.ErrIdNotFound) { + logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w) return } logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) @@ -231,7 +239,7 @@ func DeleteCertificate(env *Environment) http.HandlerFunc { id := r.PathValue("id") insertId, err := env.DB.UpdateCSR(id, "") if err != nil { - if err.Error() == "csr id not found" { + if errors.Is(err, certdb.ErrIdNotFound) { logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w) return } @@ -252,6 +260,106 @@ func DeleteCertificate(env *Environment) http.HandlerFunc { } } +// GetUserAccounts returns all users from the database +func GetUserAccounts(env *Environment) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + users, err := env.DB.RetrieveAllUsers() + if err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + return + } + for i := range users { + users[i].Password = "" + } + body, err := json.Marshal(users) + if err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + return + } + if _, err := w.Write(body); err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + } + } +} + +// GetUserAccount receives an id as a path parameter, and +// returns the corresponding User Account +func GetUserAccount(env *Environment) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + userAccount, err := env.DB.RetrieveUser(id) + if err != nil { + if errors.Is(err, certdb.ErrIdNotFound) { + logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w) + return + } + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + return + } + userAccount.Password = "" + body, err := json.Marshal(userAccount) + if err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + return + } + if _, err := w.Write(body); err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + } + } +} + +// PostUserAccount creates a new User Account, and returns the id of the created row +func PostUserAccount(env *Environment) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var user certdb.User + if err := json.NewDecoder(r.Body).Decode(&user); err != nil { + logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w) + return + } + if user.Username == "" { + logErrorAndWriteResponse("Username is required", http.StatusBadRequest, w) + return + } + if user.Password == "" { + generatedPassword, err := GeneratePassword(8) + if err != nil { + logErrorAndWriteResponse("Failed to generate password", http.StatusInternalServerError, w) + return + } + user.Password = generatedPassword + } + users, err := env.DB.RetrieveAllUsers() + if err != nil { + logErrorAndWriteResponse("Failed to retrieve users: "+err.Error(), http.StatusInternalServerError, w) + return + } + + permission := "0" + if len(users) == 0 { + permission = "1" //if this is the first user it will be admin + } + id, err := env.DB.CreateUser(user.Username, user.Password, permission) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed") { + logErrorAndWriteResponse("user with given username already exists", http.StatusBadRequest, w) + return + } + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + response, err := json.Marshal(map[string]any{"id": id, "password": user.Password}) + if err != nil { + logErrorAndWriteResponse("Error marshaling response", http.StatusInternalServerError, w) + } + if _, err := w.Write(response); err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + } + } +} + // logErrorAndWriteResponse is a helper function that logs any error and writes it back as an http response func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) { errMsg := fmt.Sprintf("error: %s", msg) @@ -261,3 +369,16 @@ func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) { logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) } } + +var GeneratePassword = func(length int) (string, error) { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789&*?@" + b := make([]byte, length) + for i := range b { + n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) + if err != nil { + return "", err + } + b[i] = charset[n.Int64()] + } + return string(b), nil +} diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index adf9dd2..4e30dbb 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -100,7 +100,14 @@ const ( expectedGetCertReqResponseBody4 = "{\"id\":2,\"csr\":\"-----BEGIN CERTIFICATE REQUEST-----\\nMIIC5zCCAc8CAQAwRzEWMBQGA1UEAwwNMTAuMTUyLjE4My41MzEtMCsGA1UELQwk\\nMzlhY2UxOTUtZGM1YS00MzJiLTgwOTAtYWZlNmFiNGI0OWNmMIIBIjANBgkqhkiG\\n9w0BAQEFAAOCAQ8AMIIBCgKCAQEAjM5Wz+HRtDveRzeDkEDM4ornIaefe8d8nmFi\\npUat9qCU3U9798FR460DHjCLGxFxxmoRitzHtaR4ew5H036HlGB20yas/CMDgSUI\\n69DyAsyPwEJqOWBGO1LL50qXdl5/jOkO2voA9j5UsD1CtWSklyhbNhWMpYqj2ObW\\nXcaYj9Gx/TwYhw8xsJ/QRWyCrvjjVzH8+4frfDhBVOyywN7sq+I3WwCbyBBcN8uO\\nyae0b/q5+UJUiqgpeOAh/4Y7qI3YarMj4cm7dwmiCVjedUwh65zVyHtQUfLd8nFW\\nKl9775mNBc1yicvKDU3ZB5hZ1MZtpbMBwaA1yMSErs/fh5KaXwIDAQABoFswWQYJ\\nKoZIhvcNAQkOMUwwSjBIBgNVHREEQTA/hwQKmLc1gjd2YXVsdC1rOHMtMC52YXVs\\ndC1rOHMtZW5kcG9pbnRzLnZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsMA0GCSqGSIb3\\nDQEBCwUAA4IBAQCJt8oVDbiuCsik4N5AOJIT7jKsMb+j0mizwjahKMoCHdx+zv0V\\nFGkhlf0VWPAdEu3gHdJfduX88WwzJ2wBBUK38UuprAyvfaZfaYUgFJQNC6DH1fIa\\nuHYEhvNJBdFJHaBvW7lrSFi57fTA9IEPrB3m/XN3r2F4eoHnaJJqHZmMwqVHck87\\ncAQXk3fvTWuikHiCHqqdSdjDYj/8cyiwCrQWpV245VSbOE0WesWoEnSdFXVUfE1+\\nRSKeTRuuJMcdGqBkDnDI22myj0bjt7q8eqBIjTiLQLnAFnQYpcCrhc8dKU9IJlv1\\nH9Hay4ZO9LRew3pEtlx2WrExw/gpUcWM8rTI\\n-----END CERTIFICATE REQUEST-----\",\"certificate\":\"\"}" ) -func TestGoCertRouter(t *testing.T) { +const ( + adminUser = `{"username": "testadmin", "password": "admin"}` + validUser = `{"username": "testuser", "password": "user"}` + noPasswordUser = `{"username": "nopass", "password": ""}` + invalidUser = `{"username": "", "password": ""}` +) + +func TestGoCertCertificatesHandlers(t *testing.T) { testdb, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") if err != nil { log.Fatalf("couldn't create test sqlite db: %s", err) @@ -205,16 +212,16 @@ func TestGoCertRouter(t *testing.T) { method: "DELETE", path: "/api/v1/certificate_requests/5", data: "", - response: "error: csr id not found", - status: http.StatusBadRequest, + response: "error: id not found", + status: http.StatusNotFound, }, { desc: "get csr1 fail", method: "GET", path: "/api/v1/certificate_requests/1", data: "", - response: "error: csr id not found", - status: http.StatusBadRequest, + response: "error: id not found", + status: http.StatusNotFound, }, { desc: "get csr2 success", @@ -351,3 +358,108 @@ func TestGoCertRouter(t *testing.T) { } } + +func TestGoCertUsersHandlers(t *testing.T) { + testdb, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") + if err != nil { + log.Fatalf("couldn't create test sqlite db: %s", err) + } + env := &server.Environment{} + env.DB = testdb + ts := httptest.NewTLSServer(server.NewGoCertRouter(env)) + defer ts.Close() + + originalFunc := server.GeneratePassword + server.GeneratePassword = func(length int) (string, error) { + return "generatedPassword", nil + } + defer func() { server.GeneratePassword = originalFunc }() + + client := ts.Client() + + testCases := []struct { + desc string + method string + path string + data string + response string + status int + }{ + { + desc: "Create first user success", + method: "POST", + path: "/api/v1/accounts", + data: adminUser, + response: "{\"id\":1,\"password\":\"admin\"}", + status: http.StatusCreated, + }, + { + desc: "Retrieve admin user success", + method: "GET", + path: "/api/v1/accounts/1", + data: "", + response: "{\"id\":1,\"username\":\"testadmin\",\"permissions\":1}", + status: http.StatusOK, + }, + { + desc: "Create second user success", + method: "POST", + path: "/api/v1/accounts", + data: validUser, + response: "{\"id\":2,\"password\":\"user\"}", + status: http.StatusCreated, + }, + { + desc: "Create no password user success", + method: "POST", + path: "/api/v1/accounts", + data: noPasswordUser, + response: "{\"id\":3,\"password\":\"generatedPassword\"}", + status: http.StatusCreated, + }, + { + desc: "Retrieve normal user success", + method: "GET", + path: "/api/v1/accounts/2", + data: "", + response: "{\"id\":2,\"username\":\"testuser\",\"permissions\":0}", + status: http.StatusOK, + }, + { + desc: "Retrieve user failure", + method: "GET", + path: "/api/v1/accounts/300", + data: "", + response: "error: id not found", + status: http.StatusNotFound, + }, + { + desc: "Create user failure", + method: "POST", + path: "/api/v1/accounts", + data: invalidUser, + response: "error: Username is required", + status: http.StatusBadRequest, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + req, err := http.NewRequest(tC.method, ts.URL+tC.path, strings.NewReader(tC.data)) + if err != nil { + t.Fatal(err) + } + res, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + resBody, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if res.StatusCode != tC.status || !strings.Contains(string(resBody), tC.response) { + t.Errorf("expected response did not match.\nExpected vs Received status code: %d vs %d\nExpected vs Received body: \n%s\nvs\n%s\n", tC.status, res.StatusCode, tC.response, string(resBody)) + } + }) + } +} diff --git a/internal/certdb/certdb.go b/internal/certdb/certdb.go index 0055de9..0d306b3 100644 --- a/internal/certdb/certdb.go +++ b/internal/certdb/certdb.go @@ -53,10 +53,12 @@ type CertificateRequest struct { type User struct { ID int `json:"id"` Username string `json:"username"` - Password string `json:"password"` + Password string `json:"password,omitempty"` Permissions int `json:"permissions"` } +var ErrIdNotFound = errors.New("id not found") + // RetrieveAllCSRs gets every CertificateRequest entry in the table. func (db *CertificateRequestsRepository) RetrieveAllCSRs() ([]CertificateRequest, error) { rows, err := db.conn.Query(fmt.Sprintf(queryGetAllCSRs, db.certificateTable)) @@ -83,7 +85,7 @@ func (db *CertificateRequestsRepository) RetrieveCSR(id string) (CertificateRequ row := db.conn.QueryRow(fmt.Sprintf(queryGetCSR, db.certificateTable), id) if err := row.Scan(&newCSR.ID, &newCSR.CSR, &newCSR.Certificate); err != nil { if err.Error() == "sql: no rows in result set" { - return newCSR, errors.New("csr id not found") + return newCSR, ErrIdNotFound } return newCSR, err } @@ -146,7 +148,7 @@ func (db *CertificateRequestsRepository) DeleteCSR(id string) (int64, error) { return 0, err } if deleteId == 0 { - return 0, errors.New("csr id not found") + return 0, ErrIdNotFound } return deleteId, nil } @@ -176,7 +178,7 @@ func (db *CertificateRequestsRepository) RetrieveUser(id string) (User, error) { row := db.conn.QueryRow(queryGetUser, id) if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil { if err.Error() == "sql: no rows in result set" { - return newUser, errors.New("user id not found") + return newUser, ErrIdNotFound } return newUser, err } @@ -235,7 +237,7 @@ func (db *CertificateRequestsRepository) DeleteUser(id string) (int64, error) { return 0, err } if deleteId == 0 { - return 0, errors.New("user id not found") + return 0, ErrIdNotFound } return deleteId, nil }