Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace 403 with 401 status code #302

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/jwt/middleware.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Setup
-----

``JWTMiddleware`` wraps an ASGI app, and ensures a valid token is passed in the header.
Otherwise a 403 error is returned. If the token is valid, the corresponding
Otherwise a 401 error is returned. If the token is valid, the corresponding
``user_id`` is added to the ASGI ``scope``.

blacklist
Expand Down
28 changes: 22 additions & 6 deletions piccolo_api/jwt_auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jwt
from piccolo.apps.user.tables import BaseUser
from starlette.exceptions import HTTPException
from starlette.status import HTTP_401_UNAUTHORIZED
from starlette.types import ASGIApp


Expand Down Expand Up @@ -126,7 +127,7 @@ async def get_user(
async def __call__(self, scope, receive, send):
"""
Add the user_id to the scope if a JWT token is available, and the user
is recognised, otherwise raise a 403 HTTP error.
is recognised, otherwise raise a 401 HTTP error.
"""
allow_unauthenticated = self.allow_unauthenticated

Expand All @@ -142,7 +143,10 @@ async def __call__(self, scope, receive, send):
)
return
else:
raise HTTPException(status_code=403, detail=error)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error,
)

if await self.blacklist.in_blacklist(token):
error = JWTError.token_revoked.value
Expand All @@ -154,7 +158,10 @@ async def __call__(self, scope, receive, send):
)
return
else:
raise HTTPException(status_code=403, detail=error)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error,
)

try:
token_dict = jwt.decode(token, self.secret, algorithms=["HS256"])
Expand All @@ -168,7 +175,10 @@ async def __call__(self, scope, receive, send):
)
return
else:
raise HTTPException(status_code=403, detail=error)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error,
)
except jwt.exceptions.InvalidSignatureError:
error = JWTError.token_invalid.value
if allow_unauthenticated:
Expand All @@ -179,7 +189,10 @@ async def __call__(self, scope, receive, send):
)
return
else:
raise HTTPException(status_code=403, detail=error)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error,
)

user = await self.get_user(token_dict)
if user is None:
Expand All @@ -192,7 +205,10 @@ async def __call__(self, scope, receive, send):
)
return
else:
raise HTTPException(status_code=403, detail=error)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=error,
)

await self.asgi(
extend_scope(scope, {"user_id": user.id}), receive, send
Expand Down
5 changes: 3 additions & 2 deletions piccolo_api/mfa/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from starlette.endpoints import HTTPEndpoint
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED

from piccolo_api.mfa.provider import MFAProvider
from piccolo_api.shared.auth.styles import Styles
Expand Down Expand Up @@ -64,7 +65,7 @@ def _render_cancel_template(
template = environment.get_template("mfa_cancel.html")

return HTMLResponse(
status_code=400,
status_code=HTTP_400_BAD_REQUEST,
content=template.render(
styles=self._styles,
csrftoken=request.scope.get("csrftoken"),
Expand Down Expand Up @@ -110,7 +111,7 @@ async def post(self, request: Request):
):
return self._render_register_template(
request=request,
status_code=403,
status_code=HTTP_401_UNAUTHORIZED,
extra_context={"error": "Incorrect password"},
)

Expand Down
30 changes: 19 additions & 11 deletions piccolo_api/session_auth/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
PlainTextResponse,
RedirectResponse,
)
from starlette.status import HTTP_303_SEE_OTHER
from starlette.status import HTTP_303_SEE_OTHER, HTTP_401_UNAUTHORIZED

from piccolo_api.mfa.provider import MFAProvider
from piccolo_api.session_auth.tables import SessionsBase
Expand Down Expand Up @@ -92,7 +92,8 @@ async def post(self, request: Request) -> Response:
cookie = request.cookies.get(self._cookie_name, None)
if not cookie:
raise HTTPException(
status_code=401, detail="The session cookie wasn't found."
status_code=HTTP_401_UNAUTHORIZED,
detail="The session cookie wasn't found.",
)
await self._session_table.remove_session(token=cookie)

Expand Down Expand Up @@ -204,11 +205,14 @@ def _get_error_response(
) -> Response:
if response_format == "html":
return self._render_template(
request, template_context={"error": error}, status_code=401
request,
template_context={"error": error},
status_code=HTTP_401_UNAUTHORIZED,
)
else:
return PlainTextResponse(
status_code=401, content=f"Login failed: {error}"
status_code=HTTP_401_UNAUTHORIZED,
content=f"Login failed: {error}",
)

async def get(self, request: Request) -> HTMLResponse:
Expand Down Expand Up @@ -261,7 +265,8 @@ async def post(self, request: Request) -> Response:
)
else:
raise HTTPException(
status_code=401, detail=validate_response
status_code=HTTP_401_UNAUTHORIZED,
detail=validate_response,
)

# Attempt login
Expand Down Expand Up @@ -314,7 +319,8 @@ async def post(self, request: Request) -> Response:
)
else:
raise HTTPException(
status_code=401, detail=message
status_code=HTTP_401_UNAUTHORIZED,
detail=message,
)

# Work out which MFA provider to use:
Expand All @@ -325,7 +331,7 @@ async def post(self, request: Request) -> Response:

if mfa_provider_name is None:
raise HTTPException(
status_code=401,
status_code=HTTP_401_UNAUTHORIZED,
detail="MFA provider must be specified",
)

Expand All @@ -337,13 +343,13 @@ async def post(self, request: Request) -> Response:

if len(filtered_mfa_providers) == 0:
raise HTTPException(
status_code=401,
status_code=HTTP_401_UNAUTHORIZED,
detail="MFA provider not recognised.",
)

if len(filtered_mfa_providers) > 1:
raise HTTPException(
status_code=401,
status_code=HTTP_401_UNAUTHORIZED,
detail=(
"Multiple matching MFA providers found."
),
Expand All @@ -368,7 +374,7 @@ async def post(self, request: Request) -> Response:
)
else:
raise HTTPException(
status_code=401,
status_code=HTTP_401_UNAUTHORIZED,
detail="MFA failed",
)

Expand Down Expand Up @@ -404,7 +410,9 @@ async def post(self, request: Request) -> Response:
},
)
else:
raise HTTPException(status_code=401, detail="Login failed")
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="Login failed"
)

now = datetime.now()
expiry_date = now + self._session_expiry
Expand Down
3 changes: 2 additions & 1 deletion piccolo_api/shared/middleware/junction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from starlette.exceptions import HTTPException
from starlette.routing import Router
from starlette.status import HTTP_404_NOT_FOUND
from starlette.types import Receive, Scope, Send


Expand All @@ -22,4 +23,4 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
else:
return

raise HTTPException(status_code=404)
raise HTTPException(status_code=HTTP_404_NOT_FOUND)
6 changes: 4 additions & 2 deletions piccolo_api/token_auth/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from starlette.endpoints import HTTPEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.status import HTTP_401_UNAUTHORIZED

from .tables import TokenAuth

Expand Down Expand Up @@ -61,11 +62,12 @@ async def post(self, request: Request) -> Response:
else:
return Response(
content="The credentials were incorrect",
status_code=401,
status_code=HTTP_401_UNAUTHORIZED,
)
else:
return Response(
content="No credentials were found.", status_code=401
content="No credentials were found.",
status_code=HTTP_401_UNAUTHORIZED,
)


Expand Down
10 changes: 5 additions & 5 deletions tests/jwt_auth/test_jwt_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_empty_token(self):
with self.assertRaises(HTTPException):
response = client.get("/")

self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
self.assertEqual(
response.json()["detail"], JWTError.token_not_found.value
)
Expand Down Expand Up @@ -103,7 +103,7 @@ def test_expired_token(self):
with self.assertRaises(HTTPException):
response = client.get("/", headers=headers)

self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
self.assertEqual(
response.json()["detail"], JWTError.token_expired.value
)
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_wrong_secret(self):
with self.assertRaises(HTTPException):
response = client.get("/", headers=headers)

self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
self.assertEqual(
response.json()["detail"], JWTError.token_invalid.value
)
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_missing_expiry(self):
with self.assertRaises(HTTPException):
response = client.get("/", headers=headers)

self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
self.assertEqual(
response.json()["detail"], JWTError.token_expired.value
)
Expand All @@ -188,7 +188,7 @@ def test_token_without_user_id(self):
with self.assertRaises(HTTPException):
response = client.get("/", headers=headers)

self.assertEqual(response.status_code, 403)
self.assertEqual(response.status_code, 401)
self.assertEqual(response.content, b"")

# allow_unauthenticated
Expand Down
Loading