Skip to content

Commit

Permalink
feat: Handlers to create and fetch accounts (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
saltiyazan committed Jul 4, 2024
1 parent 4347b8d commit 7057751
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 18 deletions.
137 changes: 129 additions & 8 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package server

import (
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"math/big"
"net/http"
"strconv"
"strings"

"github.com/canonical/gocert/internal/certdb"
metrics "github.com/canonical/gocert/internal/metrics"
"github.com/canonical/gocert/ui"
)
Expand All @@ -27,6 +31,10 @@ func NewGoCertRouter(env *Environment) http.Handler {
apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/reject", RejectCertificate(env))
apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env))

apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env))
apiV1Router.HandleFunc("GET /accounts", GetUserAccounts(env))
apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env))

m := metrics.NewMetricsSubsystem(env.DB)
frontendHandler := newFrontendFileServer()

Expand Down Expand Up @@ -124,8 +132,8 @@ func GetCertificateRequest(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
cert, err := env.DB.RetrieveCSR(id)
if err != nil {
if err.Error() == "csr id not found" {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
if errors.Is(err, certdb.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
return
}
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
Expand All @@ -149,8 +157,8 @@ func DeleteCertificateRequest(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
insertId, err := env.DB.DeleteCSR(id)
if err != nil {
if err.Error() == "csr id not found" {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
if errors.Is(err, certdb.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
return
}
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
Expand All @@ -175,7 +183,7 @@ func PostCertificate(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
insertId, err := env.DB.UpdateCSR(id, string(cert))
if err != nil {
if err.Error() == "csr id not found" ||
if errors.Is(err, certdb.ErrIdNotFound) ||
err.Error() == "certificate does not match CSR" ||
strings.Contains(err.Error(), "cert validation failed") {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
Expand Down Expand Up @@ -203,8 +211,8 @@ func RejectCertificate(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
insertId, err := env.DB.UpdateCSR(id, "rejected")
if err != nil {
if err.Error() == "csr id not found" {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
if errors.Is(err, certdb.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
return
}
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
Expand All @@ -231,7 +239,7 @@ func DeleteCertificate(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
insertId, err := env.DB.UpdateCSR(id, "")
if err != nil {
if err.Error() == "csr id not found" {
if errors.Is(err, certdb.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
return
}
Expand All @@ -252,6 +260,106 @@ func DeleteCertificate(env *Environment) http.HandlerFunc {
}
}

// GetUserAccounts returns all users from the database
func GetUserAccounts(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
users, err := env.DB.RetrieveAllUsers()
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
for i := range users {
users[i].Password = ""
}
body, err := json.Marshal(users)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
if _, err := w.Write(body); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
}

// GetUserAccount receives an id as a path parameter, and
// returns the corresponding User Account
func GetUserAccount(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
userAccount, err := env.DB.RetrieveUser(id)
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
return
}
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
userAccount.Password = ""
body, err := json.Marshal(userAccount)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
if _, err := w.Write(body); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
}

// PostUserAccount creates a new User Account, and returns the id of the created row
func PostUserAccount(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var user certdb.User
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w)
return
}
if user.Username == "" {
logErrorAndWriteResponse("Username is required", http.StatusBadRequest, w)
return
}
if user.Password == "" {
generatedPassword, err := GeneratePassword(8)
if err != nil {
logErrorAndWriteResponse("Failed to generate password", http.StatusInternalServerError, w)
return
}
user.Password = generatedPassword
}
users, err := env.DB.RetrieveAllUsers()
if err != nil {
logErrorAndWriteResponse("Failed to retrieve users: "+err.Error(), http.StatusInternalServerError, w)
return
}

permission := "0"
if len(users) == 0 {
permission = "1" //if this is the first user it will be admin
}
id, err := env.DB.CreateUser(user.Username, user.Password, permission)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
logErrorAndWriteResponse("user with given username already exists", http.StatusBadRequest, w)
return
}
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
response, err := json.Marshal(map[string]any{"id": id, "password": user.Password})
if err != nil {
logErrorAndWriteResponse("Error marshaling response", http.StatusInternalServerError, w)
}
if _, err := w.Write(response); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
}

// logErrorAndWriteResponse is a helper function that logs any error and writes it back as an http response
func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) {
errMsg := fmt.Sprintf("error: %s", msg)
Expand All @@ -261,3 +369,16 @@ func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}

var GeneratePassword = func(length int) (string, error) {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789&*?@"
b := make([]byte, length)
for i := range b {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil {
return "", err
}
b[i] = charset[n.Int64()]
}
return string(b), nil
}
122 changes: 117 additions & 5 deletions internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,14 @@ const (
expectedGetCertReqResponseBody4 = "{\"id\":2,\"csr\":\"-----BEGIN CERTIFICATE REQUEST-----\\nMIIC5zCCAc8CAQAwRzEWMBQGA1UEAwwNMTAuMTUyLjE4My41MzEtMCsGA1UELQwk\\nMzlhY2UxOTUtZGM1YS00MzJiLTgwOTAtYWZlNmFiNGI0OWNmMIIBIjANBgkqhkiG\\n9w0BAQEFAAOCAQ8AMIIBCgKCAQEAjM5Wz+HRtDveRzeDkEDM4ornIaefe8d8nmFi\\npUat9qCU3U9798FR460DHjCLGxFxxmoRitzHtaR4ew5H036HlGB20yas/CMDgSUI\\n69DyAsyPwEJqOWBGO1LL50qXdl5/jOkO2voA9j5UsD1CtWSklyhbNhWMpYqj2ObW\\nXcaYj9Gx/TwYhw8xsJ/QRWyCrvjjVzH8+4frfDhBVOyywN7sq+I3WwCbyBBcN8uO\\nyae0b/q5+UJUiqgpeOAh/4Y7qI3YarMj4cm7dwmiCVjedUwh65zVyHtQUfLd8nFW\\nKl9775mNBc1yicvKDU3ZB5hZ1MZtpbMBwaA1yMSErs/fh5KaXwIDAQABoFswWQYJ\\nKoZIhvcNAQkOMUwwSjBIBgNVHREEQTA/hwQKmLc1gjd2YXVsdC1rOHMtMC52YXVs\\ndC1rOHMtZW5kcG9pbnRzLnZhdWx0LnN2Yy5jbHVzdGVyLmxvY2FsMA0GCSqGSIb3\\nDQEBCwUAA4IBAQCJt8oVDbiuCsik4N5AOJIT7jKsMb+j0mizwjahKMoCHdx+zv0V\\nFGkhlf0VWPAdEu3gHdJfduX88WwzJ2wBBUK38UuprAyvfaZfaYUgFJQNC6DH1fIa\\nuHYEhvNJBdFJHaBvW7lrSFi57fTA9IEPrB3m/XN3r2F4eoHnaJJqHZmMwqVHck87\\ncAQXk3fvTWuikHiCHqqdSdjDYj/8cyiwCrQWpV245VSbOE0WesWoEnSdFXVUfE1+\\nRSKeTRuuJMcdGqBkDnDI22myj0bjt7q8eqBIjTiLQLnAFnQYpcCrhc8dKU9IJlv1\\nH9Hay4ZO9LRew3pEtlx2WrExw/gpUcWM8rTI\\n-----END CERTIFICATE REQUEST-----\",\"certificate\":\"\"}"
)

func TestGoCertRouter(t *testing.T) {
const (
adminUser = `{"username": "testadmin", "password": "admin"}`
validUser = `{"username": "testuser", "password": "user"}`
noPasswordUser = `{"username": "nopass", "password": ""}`
invalidUser = `{"username": "", "password": ""}`
)

func TestGoCertCertificatesHandlers(t *testing.T) {
testdb, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests")
if err != nil {
log.Fatalf("couldn't create test sqlite db: %s", err)
Expand Down Expand Up @@ -205,16 +212,16 @@ func TestGoCertRouter(t *testing.T) {
method: "DELETE",
path: "/api/v1/certificate_requests/5",
data: "",
response: "error: csr id not found",
status: http.StatusBadRequest,
response: "error: id not found",
status: http.StatusNotFound,
},
{
desc: "get csr1 fail",
method: "GET",
path: "/api/v1/certificate_requests/1",
data: "",
response: "error: csr id not found",
status: http.StatusBadRequest,
response: "error: id not found",
status: http.StatusNotFound,
},
{
desc: "get csr2 success",
Expand Down Expand Up @@ -351,3 +358,108 @@ func TestGoCertRouter(t *testing.T) {
}

}

func TestGoCertUsersHandlers(t *testing.T) {
testdb, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests")
if err != nil {
log.Fatalf("couldn't create test sqlite db: %s", err)
}
env := &server.Environment{}
env.DB = testdb
ts := httptest.NewTLSServer(server.NewGoCertRouter(env))
defer ts.Close()

originalFunc := server.GeneratePassword
server.GeneratePassword = func(length int) (string, error) {
return "generatedPassword", nil
}
defer func() { server.GeneratePassword = originalFunc }()

client := ts.Client()

testCases := []struct {
desc string
method string
path string
data string
response string
status int
}{
{
desc: "Create first user success",
method: "POST",
path: "/api/v1/accounts",
data: adminUser,
response: "{\"id\":1,\"password\":\"admin\"}",
status: http.StatusCreated,
},
{
desc: "Retrieve admin user success",
method: "GET",
path: "/api/v1/accounts/1",
data: "",
response: "{\"id\":1,\"username\":\"testadmin\",\"permissions\":1}",
status: http.StatusOK,
},
{
desc: "Create second user success",
method: "POST",
path: "/api/v1/accounts",
data: validUser,
response: "{\"id\":2,\"password\":\"user\"}",
status: http.StatusCreated,
},
{
desc: "Create no password user success",
method: "POST",
path: "/api/v1/accounts",
data: noPasswordUser,
response: "{\"id\":3,\"password\":\"generatedPassword\"}",
status: http.StatusCreated,
},
{
desc: "Retrieve normal user success",
method: "GET",
path: "/api/v1/accounts/2",
data: "",
response: "{\"id\":2,\"username\":\"testuser\",\"permissions\":0}",
status: http.StatusOK,
},
{
desc: "Retrieve user failure",
method: "GET",
path: "/api/v1/accounts/300",
data: "",
response: "error: id not found",
status: http.StatusNotFound,
},
{
desc: "Create user failure",
method: "POST",
path: "/api/v1/accounts",
data: invalidUser,
response: "error: Username is required",
status: http.StatusBadRequest,
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
req, err := http.NewRequest(tC.method, ts.URL+tC.path, strings.NewReader(tC.data))
if err != nil {
t.Fatal(err)
}
res, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
resBody, err := io.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatal(err)
}
if res.StatusCode != tC.status || !strings.Contains(string(resBody), tC.response) {
t.Errorf("expected response did not match.\nExpected vs Received status code: %d vs %d\nExpected vs Received body: \n%s\nvs\n%s\n", tC.status, res.StatusCode, tC.response, string(resBody))
}
})
}
}
Loading

0 comments on commit 7057751

Please sign in to comment.