Skip to content

Commit

Permalink
Refactor auth and permissions (#91)
Browse files Browse the repository at this point in the history
* Use built-in fetch function to get JWKS

* Move user validation to separate file

Also renamed the corresponding test file.

* Move permissions folder outside auth
  • Loading branch information
RichDom2185 authored Aug 4, 2023
1 parent 612b4d7 commit faa326f
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 61 deletions.
23 changes: 3 additions & 20 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package auth

import (
"encoding/json"
"io"
"net/http"
"context"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/sirupsen/logrus"
Expand All @@ -23,24 +21,9 @@ func getJWKS(endpointURL string) jwk.Set {
func setJwkFromEndpoint(endpointURL string) {
// Get JWK from endpoint
logrus.Debugf("Using %s as JWKS source\n", endpointURL)
resp, err := http.Get(endpointURL)
set, err := jwk.Fetch(context.Background(), endpointURL)
if err != nil {
logrus.WithError(err).Error("Failed to get JWK from endpoint")
return
}
defer resp.Body.Close()

// Parse JWK
body, err := io.ReadAll(resp.Body)
if err != nil {
logrus.WithError(err).Error("Failed to read JWK response body")
return
}

set := jwk.NewSet()
err = json.Unmarshal(body, &set)
if err != nil {
logrus.WithError(err).Error("Failed to parse JWK")
logrus.WithError(err).Error("Failed to fetch JWK from endpoint")
return
}

Expand Down
40 changes: 0 additions & 40 deletions internal/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,13 @@ import (
"errors"
"fmt"
"net/http"
"net/url"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/source-academy/stories-backend/internal/config"
"github.com/source-academy/stories-backend/internal/database"
userenums "github.com/source-academy/stories-backend/internal/enums/users"
apierrors "github.com/source-academy/stories-backend/internal/errors"
envutils "github.com/source-academy/stories-backend/internal/utils/env"
"github.com/source-academy/stories-backend/model"
"gorm.io/gorm"
)

const (
Expand Down Expand Up @@ -94,39 +90,3 @@ func MakeMiddlewareFrom(conf *config.Config) func(http.Handler) http.Handler {
})
}
}

func validateAndGetUser(queryString string, db *gorm.DB) (*model.User, error) {
// Validate valid query string
userData, err := url.ParseQuery(queryString)
if err != nil {
return nil, errors.New(invalidTokenSubjectMessage)
}

// Validate required fields
requiredFields := []string{usernameKey, loginProviderKey}
for _, field := range requiredFields {
if !userData.Has(field) {
return nil, errors.New(invalidTokenSubjectMessage)
}
}

// Validate login provider
provider, ok := userenums.LoginProviderFromString(userData.Get(loginProviderKey))
if !ok {
// Invalid/unsupported login provider
return nil, errors.New(invalidTokenSubjectMessage)
}

// Validate user
user := model.User{
Username: userData.Get(usernameKey),
LoginProvider: provider,
}
var dbUser model.User
err = db.Where(&user).First(&dbUser).Error
if err != nil {
return nil, database.HandleDBError(err, "user")
}

return &dbUser, nil
}
2 changes: 1 addition & 1 deletion internal/auth/permission.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package auth
import (
"net/http"

"github.com/source-academy/stories-backend/internal/auth/permissions"
"github.com/source-academy/stories-backend/internal/permissions"
)

func CheckPermissions(r *http.Request, requestedActionPermissions ...permissions.PermissionGroup) (bool, error) {
Expand Down
47 changes: 47 additions & 0 deletions internal/auth/user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package auth

import (
"errors"
"net/url"

"github.com/source-academy/stories-backend/internal/database"
userenums "github.com/source-academy/stories-backend/internal/enums/users"
"github.com/source-academy/stories-backend/model"
"gorm.io/gorm"
)

func validateAndGetUser(queryString string, db *gorm.DB) (*model.User, error) {
// Validate valid query string
userData, err := url.ParseQuery(queryString)
if err != nil {
return nil, errors.New(invalidTokenSubjectMessage)
}

// Validate required fields
requiredFields := []string{usernameKey, loginProviderKey}
for _, field := range requiredFields {
if !userData.Has(field) {
return nil, errors.New(invalidTokenSubjectMessage)
}
}

// Validate login provider
provider, ok := userenums.LoginProviderFromString(userData.Get(loginProviderKey))
if !ok {
// Invalid/unsupported login provider
return nil, errors.New(invalidTokenSubjectMessage)
}

// Validate user
user := model.User{
Username: userData.Get(usernameKey),
LoginProvider: provider,
}
var dbUser model.User
err = db.Where(&user).First(&dbUser).Error
if err != nil {
return nil, database.HandleDBError(err, "user")
}

return &dbUser, nil
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit faa326f

Please sign in to comment.