From d2140432a8c57ca02efc43473e8810f3ce974988 Mon Sep 17 00:00:00 2001 From: Florian Mounier Date: Wed, 16 Oct 2024 12:24:54 +0200 Subject: [PATCH] [IMP] fastapi: Factorize error handling and use it in tests with raise_server_exceptions=False --- fastapi/error_handlers.py | 46 +++++++++++++++++++++++++++-- fastapi/fastapi_dispatcher.py | 39 ++----------------------- fastapi/routers/demo_router.py | 3 +- fastapi/tests/common.py | 31 +++++++++++++++++++- fastapi/tests/test_fastapi_demo.py | 47 +++++++++++++++++++++++++++++- 5 files changed, 124 insertions(+), 42 deletions(-) diff --git a/fastapi/error_handlers.py b/fastapi/error_handlers.py index ca054a581..6f9108b0e 100644 --- a/fastapi/error_handlers.py +++ b/fastapi/error_handlers.py @@ -1,14 +1,56 @@ # Copyright 2022 ACSONE SA/NV # License LGPL-3.0 or later (http://www.gnu.org/licenses/LGPL). +from typing import Tuple - -from starlette.exceptions import WebSocketException +from starlette import status +from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware from starlette.responses import JSONResponse from starlette.websockets import WebSocket +from werkzeug.exceptions import HTTPException as WerkzeugHTTPException + +from odoo.exceptions import AccessDenied, AccessError, MissingError, UserError from fastapi import Request +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError +from fastapi.utils import is_body_allowed_for_status_code + + +def convert_exception_to_status_body(exc: Exception) -> Tuple[int, dict]: + body = {} + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + details = "Internal Server Error" + + if isinstance(exc, WerkzeugHTTPException): + status_code = exc.code + details = exc.description + elif isinstance(exc, HTTPException): + status_code = exc.status_code + details = exc.detail + elif isinstance(exc, RequestValidationError): + status_code = status.HTTP_422_UNPROCESSABLE_ENTITY + details = jsonable_encoder(exc.errors()) + elif isinstance(exc, WebSocketRequestValidationError): + status_code = status.WS_1008_POLICY_VIOLATION + details = jsonable_encoder(exc.errors()) + elif isinstance(exc, (AccessDenied, AccessError)): + status_code = status.HTTP_403_FORBIDDEN + details = "AccessError" + elif isinstance(exc, MissingError): + status_code = status.HTTP_404_NOT_FOUND + details = "MissingError" + elif isinstance(exc, UserError): + status_code = status.HTTP_400_BAD_REQUEST + details = exc.args[0] + + if is_body_allowed_for_status_code(status_code): + # use the same format as in + # fastapi.exception_handlers.http_exception_handler + body = {"detail": details} + return status_code, body + # we need to monkey patch the ServerErrorMiddleware and ExceptionMiddleware classes # to ensure that all the exceptions that are handled by these specific diff --git a/fastapi/fastapi_dispatcher.py b/fastapi/fastapi_dispatcher.py index db2f3f3c7..1a8eb3532 100644 --- a/fastapi/fastapi_dispatcher.py +++ b/fastapi/fastapi_dispatcher.py @@ -4,18 +4,10 @@ from contextlib import contextmanager from io import BytesIO -from starlette import status -from starlette.exceptions import HTTPException -from werkzeug.exceptions import HTTPException as WerkzeugHTTPException - -from odoo.exceptions import AccessDenied, AccessError, MissingError, UserError from odoo.http import Dispatcher, request -from fastapi.encoders import jsonable_encoder -from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError -from fastapi.utils import is_body_allowed_for_status_code - from .context import odoo_env_ctx +from .error_handlers import convert_exception_to_status_body class FastApiDispatcher(Dispatcher): @@ -47,34 +39,7 @@ def dispatch(self, endpoint, args): def handle_error(self, exc): headers = getattr(exc, "headers", None) - status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - details = "Internal Server Error" - if isinstance(exc, WerkzeugHTTPException): - status_code = exc.code - details = exc.description - elif isinstance(exc, HTTPException): - status_code = exc.status_code - details = exc.detail - elif isinstance(exc, RequestValidationError): - status_code = status.HTTP_422_UNPROCESSABLE_ENTITY - details = jsonable_encoder(exc.errors()) - elif isinstance(exc, WebSocketRequestValidationError): - status_code = status.WS_1008_POLICY_VIOLATION - details = jsonable_encoder(exc.errors()) - elif isinstance(exc, (AccessDenied, AccessError)): - status_code = status.HTTP_403_FORBIDDEN - details = "AccessError" - elif isinstance(exc, MissingError): - status_code = status.HTTP_404_NOT_FOUND - details = "MissingError" - elif isinstance(exc, UserError): - status_code = status.HTTP_400_BAD_REQUEST - details = exc.args[0] - body = {} - if is_body_allowed_for_status_code(status_code): - # use the same format as in - # fastapi.exception_handlers.http_exception_handler - body = {"detail": details} + status_code, body = convert_exception_to_status_body(exc) return self.request.make_json_response( body, status=status_code, headers=headers ) diff --git a/fastapi/routers/demo_router.py b/fastapi/routers/demo_router.py index 01e9eef1d..e6ce0fe3e 100644 --- a/fastapi/routers/demo_router.py +++ b/fastapi/routers/demo_router.py @@ -4,6 +4,7 @@ The demo router is a router that demonstrates how to use the fastapi integration with odoo. """ + from typing import Annotated from psycopg2 import errorcodes @@ -66,7 +67,7 @@ async def get_lang(env: Annotated[Environment, Depends(odoo_env)]): @router.get("/demo/who_ami") async def who_ami( - partner: Annotated[Partner, Depends(authenticated_partner)] + partner: Annotated[Partner, Depends(authenticated_partner)], ) -> DemoUserInfo: """Who am I? diff --git a/fastapi/tests/common.py b/fastapi/tests/common.py index 2ae0dcf4a..30a3af7e4 100644 --- a/fastapi/tests/common.py +++ b/fastapi/tests/common.py @@ -1,9 +1,14 @@ # Copyright 2023 ACSONE SA/NV # License LGPL-3.0 or later (http://www.gnu.org/licenses/LGPL). +import logging from contextlib import contextmanager from functools import partial from typing import Any, Callable, Dict +from starlette import status +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + from odoo.api import Environment from odoo.tests import tagged from odoo.tests.common import TransactionCase @@ -19,6 +24,25 @@ authenticated_partner_impl, optionally_authenticated_partner_impl, ) +from ..error_handlers import convert_exception_to_status_body + +_logger = logging.getLogger(__name__) + + +def default_exception_handler(request: Request, exc: Exception) -> Response: + """ + Default exception handler that returns a response with the exception details. + """ + status_code, body = convert_exception_to_status_body(exc) + + if status_code == status.HTTP_500_INTERNAL_SERVER_ERROR: + # In testing we want to see the exception details of 500 errors + _logger.error("[%d] Error occurred: %s", exc_info=exc) + + return JSONResponse( + status_code=status_code, + content=body, + ) @tagged("post_install", "-at_install") @@ -123,13 +147,18 @@ def _create_test_client( if router: app.include_router(router) app.dependency_overrides = dependencies + + if not raise_server_exceptions: + # Handle exceptions as in FastAPIDispatcher + app.exception_handlers.setdefault(Exception, default_exception_handler) + ctx_token = odoo_env_ctx.set(env) testclient_kwargs = testclient_kwargs or {} try: yield TestClient( app, raise_server_exceptions=raise_server_exceptions, - **testclient_kwargs + **testclient_kwargs, ) finally: odoo_env_ctx.reset(ctx_token) diff --git a/fastapi/tests/test_fastapi_demo.py b/fastapi/tests/test_fastapi_demo.py index 1692e69f3..43503b6d4 100644 --- a/fastapi/tests/test_fastapi_demo.py +++ b/fastapi/tests/test_fastapi_demo.py @@ -5,11 +5,13 @@ from requests import Response +from odoo.exceptions import UserError + from fastapi import status from ..dependencies import fastapi_endpoint from ..routers import demo_router -from ..schemas import DemoEndpointAppInfo +from ..schemas import DemoEndpointAppInfo, DemoExceptionType from .common import FastAPITransactionCase @@ -61,3 +63,46 @@ def test_endpoint_info(self) -> None: response.json(), DemoEndpointAppInfo.model_validate(demo_app).model_dump(by_alias=True), ) + + def test_exception_raised(self) -> None: + with self.assertRaisesRegex(UserError, "User Error"): + with self._create_test_client() as test_client: + test_client.get( + "/demo/exception", + params={ + "exception_type": DemoExceptionType.user_error.value, + "error_message": "User Error", + }, + ) + with self.assertRaisesRegex(NotImplementedError, "Bare Exception"): + with self._create_test_client() as test_client: + test_client.get( + "/demo/exception", + params={ + "exception_type": DemoExceptionType.bare_exception.value, + "error_message": "Bare Exception", + }, + ) + + def test_exception_not_raised(self) -> None: + with self._create_test_client(raise_server_exceptions=False) as test_client: + response: Response = test_client.get( + "/demo/exception", + params={ + "exception_type": DemoExceptionType.user_error.value, + "error_message": "User Error", + }, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"detail": "User Error"}) + + with self._create_test_client(raise_server_exceptions=False) as test_client: + response: Response = test_client.get( + "/demo/exception", + params={ + "exception_type": DemoExceptionType.bare_exception.value, + "error_message": "Bare Exception", + }, + ) + self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + self.assertDictEqual(response.json(), {"detail": "Internal Server Error"})