Skip to content

Commit

Permalink
Merge pull request #1203 from nidhey27/refactor/basic-auth-middleware
Browse files Browse the repository at this point in the history
Issue - 996 | Refactor BasicAuthMiddleware to reduce cognitive complexity
  • Loading branch information
vipul-rawat authored Dec 4, 2024
2 parents 7bb3009 + 40de59d commit fba34dd
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 34 deletions.
68 changes: 34 additions & 34 deletions pkg/gofr/http/middleware/basic_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,9 @@ func BasicAuthMiddleware(basicAuthProvider BasicAuthProvider) func(handler http.
return
}

authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, "Unauthorized: Authorization header missing", http.StatusUnauthorized)
return
}

scheme, credentials, found := strings.Cut(authHeader, " ")
if !found || scheme != "Basic" {
http.Error(w, "Unauthorized: Invalid Authorization header", http.StatusUnauthorized)
return
}

payload, err := base64.StdEncoding.DecodeString(credentials)
if err != nil {
http.Error(w, "Unauthorized: Invalid credentials format", http.StatusUnauthorized)
return
}

username, password, found := strings.Cut(string(payload), ":")
if !found {
http.Error(w, "Unauthorized: Invalid credentials", http.StatusUnauthorized)
username, password, ok := parseBasicAuth(r)
if !ok {
http.Error(w, "Unauthorized: Invalid or missing Authorization header", http.StatusUnauthorized)
return
}

Expand All @@ -58,26 +40,44 @@ func BasicAuthMiddleware(basicAuthProvider BasicAuthProvider) func(handler http.
}

ctx := context.WithValue(r.Context(), Username, username)
*r = *r.Clone(ctx)

handler.ServeHTTP(w, r)
handler.ServeHTTP(w, r.Clone(ctx))
})
}
}

// parseBasicAuth extracts and decodes the username and password from the Authorization header.
func parseBasicAuth(r *http.Request) (username, password string, ok bool) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", "", false
}

scheme, credentials, found := strings.Cut(authHeader, " ")
if !found || scheme != "Basic" {
return "", "", false
}

payload, err := base64.StdEncoding.DecodeString(credentials)
if err != nil {
return "", "", false
}

username, password, found = strings.Cut(string(payload), ":")
if !found { // Ensure both username and password are returned as empty if colon separator is missing
return "", "", false
}

return username, password, true
}

// validateCredentials checks the provided username and password against the BasicAuthProvider.
func validateCredentials(provider BasicAuthProvider, username, password string) bool {
// If ValidateFunc is provided, use it.
if provider.ValidateFunc != nil {
if provider.ValidateFunc(username, password) {
return true
}
if provider.ValidateFunc != nil && provider.ValidateFunc(username, password) {
return true
}

// If ValidateFuncWithDatasources is provided, use it.
if provider.ValidateFuncWithDatasources != nil {
if provider.ValidateFuncWithDatasources(provider.Container, username, password) {
return true
}
if provider.ValidateFuncWithDatasources != nil && provider.ValidateFuncWithDatasources(provider.Container, username, password) {
return true
}

storedPass, ok := provider.Users[username]
Expand Down
134 changes: 134 additions & 0 deletions pkg/gofr/http/middleware/basic_auth_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"encoding/base64"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -179,3 +180,136 @@ func Test_BasicAuthMiddleware_well_known(t *testing.T) {

assert.Equal(t, "Success", rr.Body.String(), "TEST Failed.\n")
}

func TestParseBasicAuth(t *testing.T) {
testCases := []struct {
name string
authHeader string
expectedUser string
expectedPass string
expectedOk bool
}{
{
name: "Valid Basic Auth",
authHeader: "Basic " + base64.StdEncoding.EncodeToString([]byte("user:password")),
expectedUser: "user",
expectedPass: "password",
expectedOk: true,
},
{
name: "Invalid Scheme",
authHeader: "Bearer token",
expectedUser: "",
expectedPass: "",
expectedOk: false,
},
{
name: "Invalid Encoding",
authHeader: "Basic invalid_base64",
expectedUser: "",
expectedPass: "",
expectedOk: false,
},
{
name: "Missing Colon Separator",
authHeader: "Basic " + base64.StdEncoding.EncodeToString([]byte("user")),
expectedUser: "",
expectedPass: "",
expectedOk: false,
},
{
name: "Empty Authorization Header",
authHeader: "",
expectedUser: "",
expectedPass: "",
expectedOk: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set("Authorization", tc.authHeader)

username, password, ok := parseBasicAuth(req)

assert.Equal(t, tc.expectedOk, ok)
assert.Equal(t, tc.expectedUser, username)
assert.Equal(t, tc.expectedPass, password)
})
}
}

func TestValidateCredentials(t *testing.T) {
validationFunc := func(user, pass string) bool {
return user == "validUser" && pass == "validPass"
}

validationFuncWithDB := func(_ *container.Container, user, pass string) bool {
return user == "dbUser" && pass == "dbPass"
}

provider := BasicAuthProvider{
Users: map[string]string{"storedUser": "storedPass"},
ValidateFunc: validationFunc,
ValidateFuncWithDatasources: validationFuncWithDB,
Container: &container.Container{},
}

testCases := []struct {
name string
username string
password string
expected bool
}{
{
name: "Valid Credentials with ValidateFunc",
username: "validUser",
password: "validPass",
expected: true,
},
{
name: "Valid Credentials with ValidateFuncWithDatasources",
username: "dbUser",
password: "dbPass",
expected: true,
},
{
name: "Valid Credentials with Stored User",
username: "storedUser",
password: "storedPass",
expected: true,
},
{
name: "Invalid Credentials",
username: "invalidUser",
password: "invalidPass",
expected: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := validateCredentials(provider, tc.username, tc.password)
assert.Equal(t, tc.expected, result)
})
}
}

func TestBasicAuthMiddleware_NoAuthHeader(t *testing.T) {
authProvider := BasicAuthProvider{
Users: map[string]string{"user": "password"},
}

middleware := BasicAuthMiddleware(authProvider)
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})

req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
rr := httptest.NewRecorder()
middleware(handler).ServeHTTP(rr, req)

assert.Equal(t, http.StatusUnauthorized, rr.Code)
assert.Contains(t, rr.Body.String(), "Unauthorized: Invalid or missing Authorization header")
}

0 comments on commit fba34dd

Please sign in to comment.