Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adds login endpoint and handlers #40

Merged
merged 8 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.48.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
Expand Down
67 changes: 67 additions & 0 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ import (
"net/http"
"strconv"
"strings"
"time"

"github.com/canonical/gocert/internal/certdb"
metrics "github.com/canonical/gocert/internal/metrics"
"github.com/canonical/gocert/ui"
"github.com/golang-jwt/jwt"
"golang.org/x/crypto/bcrypt"
)

// NewGoCertRouter takes in an environment struct, passes it along to any handlers that will need
Expand All @@ -35,6 +38,8 @@ func NewGoCertRouter(env *Environment) http.Handler {
apiV1Router.HandleFunc("GET /accounts", GetUserAccounts(env))
apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env))

apiV1Router.HandleFunc("POST /login", Login(env))

m := metrics.NewMetricsSubsystem(env.DB)
frontendHandler := newFrontendFileServer()

Expand Down Expand Up @@ -360,6 +365,53 @@ func PostUserAccount(env *Environment) http.HandlerFunc {
}
}

func Login(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
var userRequest certdb.User
if err := json.NewDecoder(r.Body).Decode(&userRequest); err != nil {
logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w)
return
}
if userRequest.Username == "" {
logErrorAndWriteResponse("Username is required", http.StatusBadRequest, w)
return
}
if userRequest.Password == "" {
logErrorAndWriteResponse("Password is required", http.StatusBadRequest, w)
return
}
userAccount, err := env.DB.RetrieveUserByUsername(userRequest.Username)
if err != nil {
status := http.StatusInternalServerError
if errors.Is(err, certdb.ErrIdNotFound) {
logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w)
return
}
logErrorAndWriteResponse(err.Error(), status, w)
return
}
if err := bcrypt.CompareHashAndPassword([]byte(userAccount.Password), []byte(userRequest.Password)); err != nil {
logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w)
return
}
jwt, err := generateJWT(userRequest.Username, env.jwtSecret, userAccount.Permissions)
if err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
response, err := json.Marshal(jwt)
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
logErrorAndWriteResponse("Failed to marshal JWT", http.StatusInternalServerError, w)
return
}
w.WriteHeader(http.StatusOK)
if _, err := w.Write(response); err != nil {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
}
}
}

// logErrorAndWriteResponse is a helper function that logs any error and writes it back as an http response
func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) {
errMsg := fmt.Sprintf("error: %s", msg)
Expand All @@ -382,3 +434,18 @@ var GeneratePassword = func(length int) (string, error) {
}
return string(b), nil
}

// Helper function to generate a JWT
func generateJWT(username, jwtSecret string, permissions int) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
"username": username,
"permissions": permissions,
"exp": time.Now().Add(time.Hour * 1).Unix(),
})
tokenString, err := token.SignedString([]byte(jwtSecret))
if err != nil {
return "", err
}

return tokenString, nil
}
64 changes: 50 additions & 14 deletions internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,12 @@ const (
)

const (
adminUser = `{"username": "testadmin", "password": "admin"}`
validUser = `{"username": "testuser", "password": "user"}`
noPasswordUser = `{"username": "nopass", "password": ""}`
invalidUser = `{"username": "", "password": ""}`
adminUser = `{"username": "testadmin", "password": "admin"}`
validUser = `{"username": "testuser", "password": "user"}`
invalidUser = `{"username": "", "password": ""}`
noPasswordUser = `{"username": "nopass", "password": ""}`
adminUserWrongPass = `{"username": "testadmin", "password": "wrongpass"}`
notExistingUser = `{"username": "not_existing", "password": "user"}`
)

func TestGoCertCertificatesHandlers(t *testing.T) {
Expand Down Expand Up @@ -268,7 +270,7 @@ func TestGoCertCertificatesHandlers(t *testing.T) {
method: "POST",
path: "/api/v1/certificate_requests/2/certificate",
data: validCert2,
response: "4",
response: "1",
status: http.StatusCreated,
},
{
Expand All @@ -284,7 +286,7 @@ func TestGoCertCertificatesHandlers(t *testing.T) {
method: "POST",
path: "/api/v1/certificate_requests/4/certificate/reject",
data: "",
response: "4",
response: "1",
status: http.StatusAccepted,
},
{
Expand All @@ -300,7 +302,7 @@ func TestGoCertCertificatesHandlers(t *testing.T) {
method: "DELETE",
path: "/api/v1/certificate_requests/2/certificate",
data: "",
response: "4",
response: "1",
status: http.StatusAccepted,
},
{
Expand Down Expand Up @@ -369,12 +371,6 @@ func TestGoCertUsersHandlers(t *testing.T) {
ts := httptest.NewTLSServer(server.NewGoCertRouter(env))
defer ts.Close()

originalFunc := server.GeneratePassword
server.GeneratePassword = func(length int) (string, error) {
return "generatedPassword", nil
}
defer func() { server.GeneratePassword = originalFunc }()

client := ts.Client()

testCases := []struct {
Expand Down Expand Up @@ -414,7 +410,7 @@ func TestGoCertUsersHandlers(t *testing.T) {
method: "POST",
path: "/api/v1/accounts",
data: noPasswordUser,
response: "{\"id\":3,\"password\":\"generatedPassword\"}",
response: "{\"id\":3,\"password\":",
status: http.StatusCreated,
},
{
Expand All @@ -441,6 +437,46 @@ func TestGoCertUsersHandlers(t *testing.T) {
response: "error: Username is required",
status: http.StatusBadRequest,
},
{
desc: "Login success",
method: "POST",
path: "/api/v1/login",
data: adminUser,
response: "jwt",
status: http.StatusOK,
},
{
desc: "Login failure missing username",
method: "POST",
path: "/api/v1/login",
data: invalidUser,
response: "Username is required",
status: http.StatusBadRequest,
},
{
desc: "Login failure missing password",
method: "POST",
path: "/api/v1/login",
data: noPasswordUser,
response: "Password is required",
status: http.StatusBadRequest,
},
{
desc: "Login failure invalid password",
method: "POST",
path: "/api/v1/login",
data: adminUserWrongPass,
response: "error: The username or password is incorrect. Try again.",
status: http.StatusUnauthorized,
},
{
desc: "Login failure invalid username",
method: "POST",
path: "/api/v1/login",
data: notExistingUser,
response: "error: The username or password is incorrect. Try again.",
status: http.StatusUnauthorized,
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
Expand Down
16 changes: 16 additions & 0 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
package server

import (
"crypto/rand"
"crypto/tls"
"encoding/hex"
"errors"
"fmt"
"log"
Expand All @@ -16,6 +18,7 @@ import (
type Environment struct {
DB *certdb.CertificateRequestsRepository
SendPebbleNotifications bool
jwtSecret string
}

func SendPebbleNotification(key, request_id string) error {
Expand All @@ -26,6 +29,14 @@ func SendPebbleNotification(key, request_id string) error {
return nil
}

func generateJWTSecret() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate JWT secret: %w", err)
}
return hex.EncodeToString(bytes), nil
}

// NewServer creates an environment and an http server with handlers that Go can start listening to
func NewServer(port int, cert []byte, key []byte, dbPath string, pebbleNotificationsEnabled bool) (*http.Server, error) {
serverCerts, err := tls.X509KeyPair(cert, key)
Expand All @@ -37,9 +48,14 @@ func NewServer(port int, cert []byte, key []byte, dbPath string, pebbleNotificat
log.Fatalf("Couldn't connect to database: %s", err)
}

jwtSecret, err := generateJWTSecret()
if err != nil {
return nil, err
}
env := &Environment{}
env.DB = db
env.SendPebbleNotifications = pebbleNotificationsEnabled
env.jwtSecret = jwtSecret
router := NewGoCertRouter(env)

s := &http.Server{
Expand Down
32 changes: 23 additions & 9 deletions internal/certdb/certdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ const queryCreateUsersTable = `CREATE TABLE IF NOT EXISTS users (
permissions INTEGER
)`
const (
queryGetAllUsers = "SELECT * FROM users"
queryGetUser = "SELECT * FROM users WHERE user_id=?"
queryCreateUser = "INSERT INTO users (username, password, permissions) VALUES (?, ?, ?)"
queryUpdateUser = "UPDATE users SET password=? WHERE user_id=?"
queryDeleteUser = "DELETE FROM users WHERE user_id=?"
queryGetAllUsers = "SELECT * FROM users"
queryGetUser = "SELECT * FROM users WHERE user_id=?"
queryGetUserByUsername = "SELECT * FROM users WHERE username=?"
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
queryCreateUser = "INSERT INTO users (username, password, permissions) VALUES (?, ?, ?)"
queryUpdateUser = "UPDATE users SET password=? WHERE user_id=?"
queryDeleteUser = "DELETE FROM users WHERE user_id=?"
)

// CertificateRequestRepository is the object used to communicate with the established repository.
Expand Down Expand Up @@ -130,11 +131,11 @@ func (db *CertificateRequestsRepository) UpdateCSR(id string, cert string) (int6
if err != nil {
return 0, err
}
insertId, err := result.LastInsertId()
affectedRows, err := result.RowsAffected()
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return 0, err
}
return insertId, nil
return affectedRows, nil
}

// DeleteCSR removes a CSR from the database alongside the certificate that may have been generated for it.
Expand Down Expand Up @@ -185,6 +186,19 @@ func (db *CertificateRequestsRepository) RetrieveUser(id string) (User, error) {
return newUser, nil
}

// RetrieveUser retrieves the id, password and the permission level of a user.
func (db *CertificateRequestsRepository) RetrieveUserByUsername(name string) (User, error) {
var newUser User
row := db.conn.QueryRow(queryGetUserByUsername, name)
if err := row.Scan(&newUser.ID, &newUser.Username, &newUser.Password, &newUser.Permissions); err != nil {
if err.Error() == "sql: no rows in result set" {
return newUser, ErrIdNotFound
}
return newUser, err
}
return newUser, nil
}

// CreateUser creates a new user from a given username, password and permission level.
// The permission level 1 represents an admin, and a 0 represents a regular user.
// The password passed in should be in plaintext. This function handles hashing and salting the password before storing it in the database.
Expand Down Expand Up @@ -219,11 +233,11 @@ func (db *CertificateRequestsRepository) UpdateUser(id, password string) (int64,
if err != nil {
return 0, err
}
insertId, err := result.LastInsertId()
affectedRows, err := result.RowsAffected()
kayra1 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return 0, err
}
return insertId, nil
return affectedRows, nil
}

// DeleteUser removes a user from the table.
Expand Down