Skip to content

Commit

Permalink
Change: prevent status code being set for unauthorized (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmstff authored Mar 12, 2024
1 parent c1de122 commit 230de56
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 7 additions & 3 deletions auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"context"
"errors"
"fmt"
"net/http"

"github.com/gin-gonic/gin"
)
Expand All @@ -28,13 +27,13 @@ func NewGinAuthMiddleware(parseRequestFunc func(ctx context.Context, authorizati
}

if err := ctx.ShouldBindHeader(&header); err != nil {
_ = ctx.AbortWithError(http.StatusUnauthorized, fmt.Errorf("could not bind header: %w", err))
AbortWithError(ctx, fmt.Errorf("could not bind header: %w", err))
return
}

userContext, err := parseRequestFunc(ctx, header.Authorization, header.Origin)
if err != nil {
_ = ctx.AbortWithError(http.StatusUnauthorized, fmt.Errorf("authorization failed: %w", err))
AbortWithError(ctx, fmt.Errorf("authorization failed: %w", err))
return
}

Expand All @@ -43,3 +42,8 @@ func NewGinAuthMiddleware(parseRequestFunc func(ctx context.Context, authorizati
ctx.Next()
}, nil
}

func AbortWithError(ctx *gin.Context, err error) {
_ = ctx.Error(err)
ctx.Abort()
}
2 changes: 0 additions & 2 deletions auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ func TestGinAuthMiddleware(t *testing.T) {
ctx.Request, _ = http.NewRequest("GET", "/", nil)
auth(ctx)

assert.Equal(t, http.StatusUnauthorized, w.Code)
require.Len(t, ctx.Errors, 1)
assert.ErrorContains(t, ctx.Errors[0], "could not bind header")
assert.ErrorContains(t, ctx.Errors[0], "Authorization")
Expand All @@ -65,7 +64,6 @@ func TestGinAuthMiddleware(t *testing.T) {
ctx.Request.Header.Add("Origin", "origin")
auth(ctx)

assert.Equal(t, http.StatusUnauthorized, w.Code)
require.Len(t, ctx.Errors, 1)
assert.ErrorContains(t, ctx.Errors[0], "authorization failed")
assert.ErrorContains(t, ctx.Errors[0], "test error")
Expand Down

0 comments on commit 230de56

Please sign in to comment.