From 1d221b2eb32b91e515588a9f3dd851b3e94ebab7 Mon Sep 17 00:00:00 2001 From: saltiyazan Date: Tue, 9 Jul 2024 11:40:14 +0200 Subject: [PATCH] feat: Adds login endpoint and handlers (#40) --- go.mod | 1 + go.sum | 2 + internal/api/handlers.go | 61 +++++++++++++++ internal/api/handlers_test.go | 138 ++++++++++++++++++++++++++++++++-- internal/api/server.go | 16 ++++ internal/certdb/certdb.go | 32 +++++--- 6 files changed, 234 insertions(+), 16 deletions(-) diff --git a/go.mod b/go.mod index f9ce4e7..297bb87 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/kr/text v0.2.0 // indirect github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.48.0 // indirect diff --git a/go.sum b/go.sum index 59b32e7..33c78c0 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/internal/api/handlers.go b/internal/api/handlers.go index d1fc33f..d5f2d80 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -12,10 +12,13 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/canonical/gocert/internal/certdb" metrics "github.com/canonical/gocert/internal/metrics" "github.com/canonical/gocert/ui" + "github.com/golang-jwt/jwt" + "golang.org/x/crypto/bcrypt" ) // NewGoCertRouter takes in an environment struct, passes it along to any handlers that will need @@ -36,6 +39,8 @@ func NewGoCertRouter(env *Environment) http.Handler { apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env)) apiV1Router.HandleFunc("DELETE /accounts/{id}", DeleteUserAccount(env)) + apiV1Router.HandleFunc("POST /login", Login(env)) + m := metrics.NewMetricsSubsystem(env.DB) frontendHandler := newFrontendFileServer() @@ -382,6 +387,47 @@ func DeleteUserAccount(env *Environment) http.HandlerFunc { } } +func Login(env *Environment) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var userRequest certdb.User + if err := json.NewDecoder(r.Body).Decode(&userRequest); err != nil { + logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w) + return + } + if userRequest.Username == "" { + logErrorAndWriteResponse("Username is required", http.StatusBadRequest, w) + return + } + if userRequest.Password == "" { + logErrorAndWriteResponse("Password is required", http.StatusBadRequest, w) + return + } + userAccount, err := env.DB.RetrieveUserByUsername(userRequest.Username) + if err != nil { + status := http.StatusInternalServerError + if errors.Is(err, certdb.ErrIdNotFound) { + logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w) + return + } + logErrorAndWriteResponse(err.Error(), status, w) + return + } + if err := bcrypt.CompareHashAndPassword([]byte(userAccount.Password), []byte(userRequest.Password)); err != nil { + logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w) + return + } + jwt, err := generateJWT(userRequest.Username, env.JWTSecret, userAccount.Permissions) + if err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + return + } + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(jwt)); err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + } + } +} + // logErrorAndWriteResponse is a helper function that logs any error and writes it back as an http response func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) { errMsg := fmt.Sprintf("error: %s", msg) @@ -404,3 +450,18 @@ var GeneratePassword = func(length int) (string, error) { } return string(b), nil } + +// Helper function to generate a JWT +func generateJWT(username, jwtSecret string, permissions int) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "username": username, + "permissions": permissions, + "exp": time.Now().Add(time.Hour * 1).Unix(), + }) + tokenString, err := token.SignedString([]byte(jwtSecret)) + if err != nil { + return "", err + } + + return tokenString, nil +} diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index d85bfe0..aa71af4 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -1,6 +1,7 @@ package server_test import ( + "fmt" "io" "log" "net/http" @@ -10,6 +11,7 @@ import ( server "github.com/canonical/gocert/internal/api" "github.com/canonical/gocert/internal/certdb" + "github.com/golang-jwt/jwt" ) const ( @@ -101,10 +103,12 @@ const ( ) const ( - adminUser = `{"username": "testadmin", "password": "admin"}` - validUser = `{"username": "testuser", "password": "user"}` - noPasswordUser = `{"username": "nopass", "password": ""}` - invalidUser = `{"username": "", "password": ""}` + adminUser = `{"username": "testadmin", "password": "admin"}` + validUser = `{"username": "testuser", "password": "user"}` + invalidUser = `{"username": "", "password": ""}` + noPasswordUser = `{"username": "nopass", "password": ""}` + adminUserWrongPass = `{"username": "testadmin", "password": "wrongpass"}` + notExistingUser = `{"username": "not_existing", "password": "user"}` ) func TestGoCertCertificatesHandlers(t *testing.T) { @@ -268,7 +272,7 @@ func TestGoCertCertificatesHandlers(t *testing.T) { method: "POST", path: "/api/v1/certificate_requests/2/certificate", data: validCert2, - response: "4", + response: "1", status: http.StatusCreated, }, { @@ -284,7 +288,7 @@ func TestGoCertCertificatesHandlers(t *testing.T) { method: "POST", path: "/api/v1/certificate_requests/4/certificate/reject", data: "", - response: "4", + response: "1", status: http.StatusAccepted, }, { @@ -300,7 +304,7 @@ func TestGoCertCertificatesHandlers(t *testing.T) { method: "DELETE", path: "/api/v1/certificate_requests/2/certificate", data: "", - response: "4", + response: "1", status: http.StatusAccepted, }, { @@ -479,3 +483,123 @@ func TestGoCertUsersHandlers(t *testing.T) { }) } } + +func TestLogin(t *testing.T) { + testdb, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests") + if err != nil { + log.Fatalf("couldn't create test sqlite db: %s", err) + } + env := &server.Environment{} + env.DB = testdb + env.JWTSecret = "secret" + ts := httptest.NewTLSServer(server.NewGoCertRouter(env)) + defer ts.Close() + + originalFunc := server.GeneratePassword + server.GeneratePassword = func(length int) (string, error) { + return "generatedPassword", nil + } + defer func() { server.GeneratePassword = originalFunc }() + + client := ts.Client() + + testCases := []struct { + desc string + method string + path string + data string + response string + status int + }{ + { + desc: "Create admin user", + method: "POST", + path: "/api/v1/accounts", + data: adminUser, + response: "{\"id\":1,\"password\":\"admin\"}", + status: http.StatusCreated, + }, + { + desc: "Login success", + method: "POST", + path: "/api/v1/login", + data: adminUser, + response: "", + status: http.StatusOK, + }, + { + desc: "Login failure missing username", + method: "POST", + path: "/api/v1/login", + data: invalidUser, + response: "Username is required", + status: http.StatusBadRequest, + }, + { + desc: "Login failure missing password", + method: "POST", + path: "/api/v1/login", + data: noPasswordUser, + response: "Password is required", + status: http.StatusBadRequest, + }, + { + desc: "Login failure invalid password", + method: "POST", + path: "/api/v1/login", + data: adminUserWrongPass, + response: "error: The username or password is incorrect. Try again.", + status: http.StatusUnauthorized, + }, + { + desc: "Login failure invalid username", + method: "POST", + path: "/api/v1/login", + data: notExistingUser, + response: "error: The username or password is incorrect. Try again.", + status: http.StatusUnauthorized, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + req, err := http.NewRequest(tC.method, ts.URL+tC.path, strings.NewReader(tC.data)) + if err != nil { + t.Fatal(err) + } + res, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + resBody, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if res.StatusCode != tC.status || !strings.Contains(string(resBody), tC.response) { + t.Errorf("expected response did not match.\nExpected vs Received status code: %d vs %d\nExpected vs Received body: \n%s\nvs\n%s\n", tC.status, res.StatusCode, tC.response, string(resBody)) + } + if tC.desc == "Login success" && res.StatusCode == http.StatusOK { + token, parseErr := jwt.Parse(string(resBody), func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + return []byte(env.JWTSecret), nil + }) + if parseErr != nil { + t.Errorf("Error parsing JWT: %v", parseErr) + return + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + if claims["username"] != "testadmin" { + t.Errorf("Username found in JWT does not match expected value.") + } else if int(claims["permissions"].(float64)) != 1 { + t.Errorf("Permissions found in JWT does not match expected value.") + } + } else { + t.Errorf("Invalid JWT token or JWT claims are not readable") + } + } + }) + } +} diff --git a/internal/api/server.go b/internal/api/server.go index ee2ca8b..30b78f1 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -2,7 +2,9 @@ package server import ( + "crypto/rand" "crypto/tls" + "encoding/hex" "errors" "fmt" "log" @@ -16,6 +18,7 @@ import ( type Environment struct { DB *certdb.CertificateRequestsRepository SendPebbleNotifications bool + JWTSecret string } func SendPebbleNotification(key, request_id string) error { @@ -26,6 +29,14 @@ func SendPebbleNotification(key, request_id string) error { return nil } +func generateJWTSecret() (string, error) { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("failed to generate JWT secret: %w", err) + } + return hex.EncodeToString(bytes), nil +} + // NewServer creates an environment and an http server with handlers that Go can start listening to func NewServer(port int, cert []byte, key []byte, dbPath string, pebbleNotificationsEnabled bool) (*http.Server, error) { serverCerts, err := tls.X509KeyPair(cert, key) @@ -37,9 +48,14 @@ func NewServer(port int, cert []byte, key []byte, dbPath string, pebbleNotificat log.Fatalf("Couldn't connect to database: %s", err) } + jwtSecret, err := generateJWTSecret() + if err != nil { + return nil, err + } env := &Environment{} env.DB = db env.SendPebbleNotifications = pebbleNotificationsEnabled + env.JWTSecret = jwtSecret router := NewGoCertRouter(env) s := &http.Server{ diff --git a/internal/certdb/certdb.go b/internal/certdb/certdb.go index 0d306b3..9dbbf52 100644 --- a/internal/certdb/certdb.go +++ b/internal/certdb/certdb.go @@ -30,11 +30,12 @@ const queryCreateUsersTable = `CREATE TABLE IF NOT EXISTS users ( permissions INTEGER )` const ( - queryGetAllUsers = "SELECT * FROM users" - queryGetUser = "SELECT * FROM users WHERE user_id=?" - queryCreateUser = "INSERT INTO users (username, password, permissions) VALUES (?, ?, ?)" - queryUpdateUser = "UPDATE users SET password=? WHERE user_id=?" - queryDeleteUser = "DELETE FROM users WHERE user_id=?" + queryGetAllUsers = "SELECT * FROM users" + queryGetUser = "SELECT * FROM users WHERE user_id=?" + queryGetUserByUsername = "SELECT * FROM users WHERE username=?" + queryCreateUser = "INSERT INTO users (username, password, permissions) VALUES (?, ?, ?)" + queryUpdateUser = "UPDATE users SET password=? WHERE user_id=?" + queryDeleteUser = "DELETE FROM users WHERE user_id=?" ) // CertificateRequestRepository is the object used to communicate with the established repository. @@ -130,11 +131,11 @@ func (db *CertificateRequestsRepository) UpdateCSR(id string, cert string) (int6 if err != nil { return 0, err } - insertId, err := result.LastInsertId() + affectedRows, err := result.RowsAffected() if err != nil { return 0, err } - return insertId, nil + return affectedRows, nil } // DeleteCSR removes a CSR from the database alongside the certificate that may have been generated for it. @@ -185,6 +186,19 @@ func (db *CertificateRequestsRepository) RetrieveUser(id string) (User, error) { return newUser, nil } +// RetrieveUser retrieves the id, password and the permission level of a user. +func (db *CertificateRequestsRepository) RetrieveUserByUsername(name string) (User, error) { + var newUser User + row := db.conn.QueryRow(queryGetUserByUsername, name) + if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil { + if err.Error() == "sql: no rows in result set" { + return newUser, ErrIdNotFound + } + return newUser, err + } + return newUser, nil +} + // CreateUser creates a new user from a given username, password and permission level. // The permission level 1 represents an admin, and a 0 represents a regular user. // The password passed in should be in plaintext. This function handles hashing and salting the password before storing it in the database. @@ -219,11 +233,11 @@ func (db *CertificateRequestsRepository) UpdateUser(id, password string) (int64, if err != nil { return 0, err } - insertId, err := result.LastInsertId() + affectedRows, err := result.RowsAffected() if err != nil { return 0, err } - return insertId, nil + return affectedRows, nil } // DeleteUser removes a user from the table.