Skip to content

Commit

Permalink
Fix exception overhandling in middleware (GH-44)
Browse files Browse the repository at this point in the history
Co-authored-by: David García Garzón <[email protected]>
  • Loading branch information
ArtyomVancyan and vokimon authored Aug 19, 2024
2 parents c7ca1ce + ad31cba commit e388257
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 15 deletions.
9 changes: 7 additions & 2 deletions docs/integration/integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,24 @@ section covers its integration into a FastAPI app.

The `OAuth2Middleware` is an authentication middleware which means that its usage makes the `user` and `auth` attributes
available in the [request](https://www.starlette.io/requests/) context. It has a mandatory argument `config` of
[`OAuth2Config`](/integration/configuration#oauth2config) instance that has been discussed at the previous section and
an optional argument `callback` which is a callable that is called when the authentication succeeds.
[`OAuth2Config`](/integration/configuration#oauth2config) instance that has been discussed in the previous section and
optional arguments `callback` and `on_error` that accept callables as values and are called when the authentication
succeeds and fails correspondingly.

```python
app: FastAPI

def on_auth_success(auth: Auth, user: User):
"""This could be async function as well."""

def on_auth_error(conn: HTTPConnection, exc: Exception) -> Response:
return JSONResponse({"detail": str(exc)}, status_code=400)

app.add_middleware(
OAuth2Middleware,
config=OAuth2Config(...),
callback=on_auth_success,
on_error=on_auth_error,
)
```

Expand Down
2 changes: 1 addition & 1 deletion docs/references/tutorials.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def error_handler(request: Request, exc: OAuth2AuthenticationError):
return RedirectResponse(url="/login", status_code=303)
```

The complete list of exceptions is the following.
The complete list of exceptions raised by the middleware is the following.

- `OAuth2Error` - Base exception for all errors raised by the FastAPI OAuth2 library.
- `OAuth2AuthenticationError` - An exception is raised when the authentication fails.
Expand Down
2 changes: 1 addition & 1 deletion src/fastapi_oauth2/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.0"
__version__ = "1.1.0"
23 changes: 12 additions & 11 deletions src/fastapi_oauth2/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
from jose.jwt import encode as jwt_encode
from starlette.authentication import AuthCredentials
from starlette.authentication import AuthenticationBackend
from starlette.authentication import AuthenticationError
from starlette.authentication import BaseUser
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import HTTPConnection
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.responses import Response
from starlette.types import ASGIApp
from starlette.types import Receive
from starlette.types import Scope
Expand All @@ -28,7 +30,6 @@
from .claims import Claims
from .config import OAuth2Config
from .core import OAuth2Core
from .exceptions import OAuth2AuthenticationError


class Auth(AuthCredentials):
Expand Down Expand Up @@ -108,9 +109,12 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
if not scheme or not param:
return Auth(), User()

token_data = Auth.jwt_decode(param)
try:
token_data = Auth.jwt_decode(param)
except JOSEError as e:
raise AuthenticationError(str(e))
if token_data["exp"] and token_data["exp"] < int(datetime.now(timezone.utc).timestamp()):
raise OAuth2AuthenticationError(401, "Token expired")
raise AuthenticationError("Token expired")

user = User(token_data)
auth = Auth(user.pop("scope", []))
Expand All @@ -135,7 +139,7 @@ def __init__(
app: ASGIApp,
config: Union[OAuth2Config, dict],
callback: Callable[[Auth, User], Union[Awaitable[None], None]] = None,
**kwargs, # AuthenticationMiddleware kwargs
on_error: Optional[Callable[[HTTPConnection, AuthenticationError], Response]] = None,
) -> None:
"""Initiates the middleware with the given configuration.
Expand All @@ -148,13 +152,10 @@ def __init__(
elif not isinstance(config, OAuth2Config):
raise TypeError("config is not a valid type")
self.default_application_middleware = app
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), **kwargs)
on_error = on_error or AuthenticationMiddleware.default_on_error
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), on_error=on_error)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
try:
return await self.auth_middleware(scope, receive, send)
except (JOSEError, Exception) as e:
middleware = PlainTextResponse(str(e), status_code=401)
return await middleware(scope, receive, send)
return await self.auth_middleware(scope, receive, send)
await self.default_application_middleware(scope, receive, send)
67 changes: 67 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
from fastapi.responses import JSONResponse
from httpx import AsyncClient
from jose import jwt


@pytest.mark.anyio
Expand All @@ -26,3 +28,68 @@ async def test_middleware_on_logout(get_app):

response = await client.get("/user")
assert response.status_code == 403 # Forbidden


@pytest.mark.anyio
async def test_middleware_do_not_interfere_user_errors(get_app):
app = get_app()

@app.get("/unexpected_error")
def my_entry_point():
raise NameError # Intended code error

async with AsyncClient(app=app, base_url="http://test") as client:
with pytest.raises(NameError):
await client.get("/unexpected_error")


@pytest.mark.anyio
async def test_middleware_ignores_custom_exceptions(get_app):
class MyCustomException(Exception):
pass

app = get_app()

@app.get("/custom_exception")
def my_entry_point():
raise MyCustomException()

async with AsyncClient(app=app, base_url="http://test") as client:
with pytest.raises(MyCustomException):
await client.get("/custom_exception")


@pytest.mark.anyio
async def test_middleware_ignores_handled_custom_exceptions(get_app):
class MyHandledException(Exception):
pass

app = get_app()

@app.exception_handler(MyHandledException)
async def unicorn_exception_handler(request, exc):
return JSONResponse(
status_code=418,
content={"details": "I am a custom Teapot!"},
)

@app.get("/handled_exception")
def my_entry_point():
raise MyHandledException()

async with AsyncClient(app=app, base_url="http://test") as client:
response = await client.get("/handled_exception")
assert response.status_code == 418 # I am a teapot!
assert response.json() == {"details": "I am a custom Teapot!"}


@pytest.mark.anyio
async def test_middleware_reports_invalid_jwt(get_app):
async with AsyncClient(app=get_app(with_ssr=False), base_url="http://test") as client:
# Insert a bad token instead
badtoken = jwt.encode({"bad": "token"}, "badsecret", "HS256")
client.cookies.update(dict(Authorization=f"Bearer: {badtoken}"))

response = await client.get("/user")
assert response.status_code == 400
assert response.text == "Signature verification failed."

0 comments on commit e388257

Please sign in to comment.