From 4fb1e46d2efd3c0bb97a2de56ebfd59deb9548c0 Mon Sep 17 00:00:00 2001 From: Jacob Coffee Date: Wed, 27 Mar 2024 01:55:59 -0500 Subject: [PATCH 01/19] fix(docs): remove reverted changelog item --- docs/release-notes/changelog.rst | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/release-notes/changelog.rst b/docs/release-notes/changelog.rst index 75caa3a854..5bbc846f09 100644 --- a/docs/release-notes/changelog.rst +++ b/docs/release-notes/changelog.rst @@ -6,12 +6,6 @@ .. changelog:: 2.7.1 :date: 2024-03-22 - .. change:: add default encoders for `Enums` and `EnumMeta` - :type: bugfix - :pr: 3193 - - This addresses an issue when serializing ``Enums`` that was reported in discord. - .. change:: replace TestClient.__enter__ return type with Self :type: bugfix :pr: 3194 From 012bcee5b1b4f20b98ebd6998b4d676b92397c0d Mon Sep 17 00:00:00 2001 From: Jacob Coffee Date: Wed, 27 Mar 2024 02:09:30 -0500 Subject: [PATCH 02/19] docs: clarify preferred platforms for sponsoring (#3269) --- README.md | 6 +++--- docs/PYPI_README.md | 6 +++--- docs/index.rst | 5 ++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index c9d335d92a..2cc574231e 100644 --- a/README.md +++ b/README.md @@ -128,10 +128,10 @@ A **huge** thanks to our sponsors: Check out our sponsors in the docs -If you would like to support the work that we do please consider [becoming a sponsor][sponsor-github] -on [GitHub][sponsor-github] or [Open Collective][sponsor-oc]. +If you would like to support the work that we do please consider [becoming a sponsor][sponsor-polar] +via [Polar.sh][sponsor-polar] (preferred), [GitHub][sponsor-github] or [Open Collective][sponsor-oc]. -We also participate in pledge-based sponsorship with [Polar][sponsor-polar]. +Also, exclusively with [Polar][sponsor-polar], you can engage in pledge-based sponsorships. [sponsor-github]: https://github.com/sponsors/litestar-org [sponsor-oc]: https://opencollective.com/litestar diff --git a/docs/PYPI_README.md b/docs/PYPI_README.md index da66be1a3f..361c186589 100644 --- a/docs/PYPI_README.md +++ b/docs/PYPI_README.md @@ -125,10 +125,10 @@ A **huge** thanks to our sponsors: Check out our sponsors in the docs -If you would like to support the work that we do please consider [becoming a sponsor][sponsor-github] -on [GitHub][sponsor-github] or [Open Collective][sponsor-oc]. +If you would like to support the work that we do please consider [becoming a sponsor][sponsor-polar] +via [Polar.sh][sponsor-polar] (preferred), [GitHub][sponsor-github] or [Open Collective][sponsor-oc]. -We also participate in pledge-based sponsorship with [Polar][sponsor-polar]. +Also, exclusively with [Polar][sponsor-polar], you can engage in pledge-based sponsorships. [sponsor-github]: https://github.com/sponsors/litestar-org [sponsor-oc]: https://opencollective.com/litestar diff --git a/docs/index.rst b/docs/index.rst index be01d21e10..d8c2822765 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -164,11 +164,10 @@ A huge thank you to our current sponsors: We invite organizations and individuals to join our sponsorship program. -By becoming a sponsor on platforms like `GitHub `_ +By becoming a sponsor on `Polar `_ (preferred), or other platforms like `GitHub `_ and `Open Collective `_, you can play a pivotal role in our project's growth. -Additionally, we engage in pledge-based sponsorship opportunities through `Polar `_. - +Also, exclusively with `Polar `_, you can engage in pledge-based sponsorships. .. _sponsor-github: https://github.com/sponsors/litestar-org .. _sponsor-oc: https://opencollective.com/litestar From e8183312ce3549057a4d13cd2733947165c40060 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Wed, 27 Mar 2024 18:16:09 +1000 Subject: [PATCH 03/19] docs: add sherbang as a contributor for doc (#3270) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] --------- Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 9 +++++++++ README.md | 1 + 2 files changed, 10 insertions(+) diff --git a/.all-contributorsrc b/.all-contributorsrc index ddbacfcd34..e0ca6b8c09 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -1688,6 +1688,15 @@ "contributions": [ "bug" ] + }, + { + "login": "sherbang", + "name": "sherbang", + "avatar_url": "https://avatars.githubusercontent.com/u/275015?v=4", + "profile": "https://github.com/sherbang", + "contributions": [ + "doc" + ] } ], "contributorsPerLine": 7, diff --git a/README.md b/README.md index 2cc574231e..446f8f6c27 100644 --- a/README.md +++ b/README.md @@ -553,6 +553,7 @@ see [the contribution guide](CONTRIBUTING.rst). James Bennett
James Bennett

🐛 + sherbang
sherbang

📖 From b38f530bd806bec157a9f31c4a09efc070f5f772 Mon Sep 17 00:00:00 2001 From: Jacob Coffee Date: Wed, 27 Mar 2024 19:59:39 -0500 Subject: [PATCH 04/19] fix(cli): remove duplicate rich-click config options (#3274) --- litestar/cli/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/litestar/cli/__init__.py b/litestar/cli/__init__.py index f6c366e495..7dabefaae9 100644 --- a/litestar/cli/__init__.py +++ b/litestar/cli/__init__.py @@ -14,8 +14,6 @@ click.rich_click.USE_MARKDOWN = False click.rich_click.SHOW_ARGUMENTS = True click.rich_click.GROUP_ARGUMENTS_OPTIONS = True - click.rich_click.SHOW_ARGUMENTS = True - click.rich_click.GROUP_ARGUMENTS_OPTIONS = True click.rich_click.STYLE_ERRORS_SUGGESTION = "magenta italic" click.rich_click.ERRORS_SUGGESTION = "" click.rich_click.ERRORS_EPILOGUE = "" From d2cb891fcc5264620674aa0010f0a0aab682aa85 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Thu, 28 Mar 2024 20:18:34 +1000 Subject: [PATCH 05/19] ci: run slotscheck cli as ci job (#3275) This PR moves running slotscheck from within pre-commit to its own job in ci (same as we've done with other tools that require the configured environment to run). This partially fixes an issue identified in #3251 where the `litestar` package was inadvertently excluded from this check due to an incorrect regex pattern. Applies fixes for issues identified by the tool. Co-authored-by: Arie Bovenberg --- .github/workflows/ci.yml | 25 +++++++++++++++++++ .pre-commit-config.yaml | 5 ---- Makefile | 8 +++++- litestar/app.py | 4 --- litestar/contrib/opentelemetry/middleware.py | 2 -- litestar/dto/_types.py | 3 --- .../handlers/websocket_handlers/listener.py | 5 ++-- .../websocket_handlers/route_handler.py | 2 ++ litestar/middleware/compression/facade.py | 2 ++ litestar/middleware/cors.py | 2 -- litestar/middleware/logging.py | 2 -- litestar/middleware/rate_limit.py | 2 -- litestar/plugins/base.py | 2 ++ litestar/stores/base.py | 4 +++ litestar/stores/redis.py | 7 +++++- pyproject.toml | 6 +++++ 16 files changed, 56 insertions(+), 25 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bbf0986b56..766279dcb1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -81,6 +81,31 @@ jobs: - name: Run pyright run: pdm run pyright + slotscheck: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.8" + allow-prereleases: false + + - uses: pdm-project/setup-pdm@v4 + name: Set up PDM + with: + python-version: "3.8" + allow-python-prereleases: false + cache: true + cache-dependency-path: | + ./pdm.lock + + - name: Install dependencies + run: pdm install -G:all + + - name: Run slotscheck + run: pdm run slotscheck litestar + test: name: "test (${{ matrix.python-version }})" strategy: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index abdb4aac6d..ecf00f31c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,11 +41,6 @@ repos: - id: ensure-dunder-all exclude: "test*|examples*|tools" args: ["--use-tuple"] - - repo: https://github.com/ariebovenberg/slotscheck - rev: v0.19.0 - hooks: - - id: slotscheck - exclude: "test_*|docs|.github" - repo: https://github.com/sphinx-contrib/sphinx-lint rev: "v0.9.1" hooks: diff --git a/Makefile b/Makefile index 4407d391da..b17df23549 100644 --- a/Makefile +++ b/Makefile @@ -102,8 +102,14 @@ pre-commit: ## Runs pre-commit hooks; includes ruff formatting and lin @$(PDM) run pre-commit run --all-files @echo "=> Pre-commit complete" +.PHONY: slots-check +slots-check: ## Check for slots usage in classes + @echo "=> Checking for slots usage in classes" + @$(PDM) run slotscheck litestar + @echo "=> Slots check complete" + .PHONY: lint -lint: pre-commit type-check ## Run all linting +lint: pre-commit type-check slots-check ## Run all linting .PHONY: coverage coverage: ## Run the tests and generate coverage report diff --git a/litestar/app.py b/litestar/app.py index e1bd989d75..928617e322 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -150,21 +150,17 @@ class Litestar(Router): "csrf_config", "event_emitter", "get_logger", - "include_in_schema", "logger", "logging_config", "multipart_form_part_limit", "on_shutdown", "on_startup", "openapi_config", - "request_class", "response_cache_config", "route_map", - "signature_namespace", "state", "stores", "template_engine", - "websocket_class", "pdb_on_exception", "experimental_features", ) diff --git a/litestar/contrib/opentelemetry/middleware.py b/litestar/contrib/opentelemetry/middleware.py index 762bae9125..59ea4dce06 100644 --- a/litestar/contrib/opentelemetry/middleware.py +++ b/litestar/contrib/opentelemetry/middleware.py @@ -24,8 +24,6 @@ class OpenTelemetryInstrumentationMiddleware(AbstractMiddleware): """OpenTelemetry Middleware.""" - __slots__ = ("open_telemetry_middleware",) - def __init__(self, app: ASGIApp, config: OpenTelemetryConfig) -> None: """Middleware that adds OpenTelemetry instrumentation to the application. diff --git a/litestar/dto/_types.py b/litestar/dto/_types.py index 24e99b793b..b0863b2593 100644 --- a/litestar/dto/_types.py +++ b/litestar/dto/_types.py @@ -96,9 +96,6 @@ class MappingType(CompositeType): @dataclass(frozen=True) class TransferDTOFieldDefinition(DTOFieldDefinition): __slots__ = ( - "default_factory", - "dto_field", - "model_name", "is_excluded", "is_partial", "serialization_name", diff --git a/litestar/handlers/websocket_handlers/listener.py b/litestar/handlers/websocket_handlers/listener.py index 86fefc913a..8e702ea1aa 100644 --- a/litestar/handlers/websocket_handlers/listener.py +++ b/litestar/handlers/websocket_handlers/listener.py @@ -62,11 +62,10 @@ class WebsocketListenerRouteHandler(WebsocketRouteHandler): "connection_accept_handler": "Callback to accept a WebSocket connection. By default, calls WebSocket.accept", "on_accept": "Callback invoked after a WebSocket connection has been accepted", "on_disconnect": "Callback invoked after a WebSocket connection has been closed", - "weboscket_class": "WebSocket class", "_connection_lifespan": None, - "_handle_receive": None, - "_handle_send": None, + "_receive_handler": None, "_receive_mode": None, + "_send_handler": None, "_send_mode": None, } diff --git a/litestar/handlers/websocket_handlers/route_handler.py b/litestar/handlers/websocket_handlers/route_handler.py index edb49c3030..4b8953ee26 100644 --- a/litestar/handlers/websocket_handlers/route_handler.py +++ b/litestar/handlers/websocket_handlers/route_handler.py @@ -18,6 +18,8 @@ class WebsocketRouteHandler(BaseRouteHandler): Use this decorator to decorate websocket handler functions. """ + __slots__ = ("websocket_class",) + def __init__( self, path: str | list[str] | None = None, diff --git a/litestar/middleware/compression/facade.py b/litestar/middleware/compression/facade.py index 0074b57419..a1a62728ce 100644 --- a/litestar/middleware/compression/facade.py +++ b/litestar/middleware/compression/facade.py @@ -12,6 +12,8 @@ class CompressionFacade(Protocol): """A unified facade offering a uniform interface for different compression libraries.""" + __slots__ = () + encoding: ClassVar[str] """The encoding of the compression.""" diff --git a/litestar/middleware/cors.py b/litestar/middleware/cors.py index 6c4de31f8f..010576aa6a 100644 --- a/litestar/middleware/cors.py +++ b/litestar/middleware/cors.py @@ -17,8 +17,6 @@ class CORSMiddleware(AbstractMiddleware): """CORS Middleware.""" - __slots__ = ("config",) - def __init__(self, app: ASGIApp, config: CORSConfig) -> None: """Middleware that adds CORS validation to the application. diff --git a/litestar/middleware/logging.py b/litestar/middleware/logging.py index 0094f10cfa..c986eb6433 100644 --- a/litestar/middleware/logging.py +++ b/litestar/middleware/logging.py @@ -48,8 +48,6 @@ class LoggingMiddleware(AbstractMiddleware): """Logging middleware.""" - __slots__ = ("config", "logger", "request_extractor", "response_extractor", "is_struct_logger") - logger: Logger def __init__(self, app: ASGIApp, config: LoggingMiddlewareConfig) -> None: diff --git a/litestar/middleware/rate_limit.py b/litestar/middleware/rate_limit.py index cd767ba4a2..0c3de7f6e5 100644 --- a/litestar/middleware/rate_limit.py +++ b/litestar/middleware/rate_limit.py @@ -41,8 +41,6 @@ class CacheObject: class RateLimitMiddleware(AbstractMiddleware): """Rate-limiting middleware.""" - __slots__ = ("app", "check_throttle_handler", "max_requests", "unit", "request_quota", "config") - def __init__(self, app: ASGIApp, config: RateLimitConfig) -> None: """Initialize ``RateLimitMiddleware``. diff --git a/litestar/plugins/base.py b/litestar/plugins/base.py index afc571efe7..65710c9fc0 100644 --- a/litestar/plugins/base.py +++ b/litestar/plugins/base.py @@ -212,6 +212,8 @@ def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: S class OpenAPISchemaPlugin(OpenAPISchemaPluginProtocol): """Plugin to extend the support of OpenAPI schema generation for non-library types.""" + __slots__ = () + @staticmethod def is_plugin_supported_type(value: Any) -> bool: """Given a value of indeterminate type, determine if this value is supported by the plugin. diff --git a/litestar/stores/base.py b/litestar/stores/base.py index 34aa514fca..69a63663e5 100644 --- a/litestar/stores/base.py +++ b/litestar/stores/base.py @@ -20,6 +20,8 @@ class Store(ABC): """Thread and process safe asynchronous key/value store.""" + __slots__ = () + @abstractmethod async def set(self, key: str, value: str | bytes, expires_in: int | timedelta | None = None) -> None: """Set a value. @@ -97,6 +99,8 @@ class NamespacedStore(Store): should be isolated. """ + __slots__ = ("namespace",) + @abstractmethod def with_namespace(self, namespace: str) -> Self: """Return a new instance of :class:`NamespacedStore`, which exists in a child namespace of the current namespace. diff --git a/litestar/stores/redis.py b/litestar/stores/redis.py index 6697962fab..4b46097199 100644 --- a/litestar/stores/redis.py +++ b/litestar/stores/redis.py @@ -21,7 +21,12 @@ class RedisStore(NamespacedStore): """Redis based, thread and process safe asynchronous key/value store.""" - __slots__ = ("_redis",) + __slots__ = ( + "_delete_all_script", + "_get_and_renew_script", + "_redis", + "handle_client_shutdown", + ) def __init__( self, redis: Redis, namespace: str | None | EmptyType = Empty, handle_client_shutdown: bool = False diff --git a/pyproject.toml b/pyproject.toml index 1dfc87b68b..3eb1d2489d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -285,6 +285,12 @@ reportUnnecessaryTypeIgnoreComments = true [tool.slotscheck] strict-imports = false +exclude-classes = """ +( + # github.com/python/cpython/pull/106771 + (^litestar.events.emitter:BaseEventEmitterBackend) +) +""" [tool.ruff] lint.select = [ From 043a044f3fd25fa7f38e223eb071b9390c35bc71 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Fri, 29 Mar 2024 18:58:48 +1000 Subject: [PATCH 06/19] fix: pydantic `json_schema_extra` examples. (#3281) Fixes a regression introduced in 2.7 where an example for a field provided in pydantic's `Field.json_schema_extra` would cause an error. Closes #3277 Co-authored-by: avikstroem --- litestar/typing.py | 2 +- .../test_pydantic/test_openapi.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/litestar/typing.py b/litestar/typing.py index 3a275573f3..14a92dc44f 100644 --- a/litestar/typing.py +++ b/litestar/typing.py @@ -96,7 +96,7 @@ def _parse_metadata(value: Any, is_sequence_container: bool, extra: dict[str, An example_list: list[Any] | None if example := extra.pop("example", None): example_list = [Example(value=example)] - elif examples := getattr(value, "examples", None): + elif examples := (extra.pop("examples", None) or getattr(value, "examples", None)): example_list = [Example(value=example) for example in cast("list[str]", examples)] else: example_list = None diff --git a/tests/unit/test_contrib/test_pydantic/test_openapi.py b/tests/unit/test_contrib/test_pydantic/test_openapi.py index aa84def09b..4407b79980 100644 --- a/tests/unit/test_contrib/test_pydantic/test_openapi.py +++ b/tests/unit/test_contrib/test_pydantic/test_openapi.py @@ -553,6 +553,26 @@ class Model(pydantic_v2.BaseModel): assert value.examples == ["example"] +def test_create_schema_for_field_v2__examples() -> None: + class Model(pydantic_v2.BaseModel): + value: str = pydantic_v2.Field( + title="title", description="description", max_length=16, json_schema_extra={"examples": ["example"]} + ) + + schema = get_schema_for_field_definition( + FieldDefinition.from_kwarg(name="Model", annotation=Model), plugins=[PydanticSchemaPlugin()] + ) + + assert schema.properties + + value = schema.properties["value"] + + assert isinstance(value, Schema) + assert value.description == "description" + assert value.title == "title" + assert value.examples == ["example"] + + @pytest.mark.parametrize("with_future_annotations", [True, False]) def test_create_schema_for_pydantic_model_with_annotated_model_attribute( with_future_annotations: bool, create_module: "Callable[[str], ModuleType]", pydantic_version: PydanticVersion From 7ec88035af73cad4a6318405e63e3b0a8c9aa828 Mon Sep 17 00:00:00 2001 From: Cody Fincher <204685+cofin@users.noreply.github.com> Date: Tue, 12 Mar 2024 22:09:05 -0500 Subject: [PATCH 07/19] feat: allow for console output to be silenced (#3180) In some cases, I've wanted to change the name of the "Application" to show something else instead of "Litestar". This is usually to make a CLI feel more cohesive and part of a larger application. In the same vein, I've found cases where I wanted to complete suppress the initial `from_env` or the Rich info table at startup. --- litestar/cli/_utils.py | 5 +- litestar/cli/commands/core.py | 8 +-- tests/unit/test_cli/test_core_commands.py | 60 +++++++++++++++++++++-- 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/litestar/cli/_utils.py b/litestar/cli/_utils.py index f36cd7703c..84ebd3b1af 100644 --- a/litestar/cli/_utils.py +++ b/litestar/cli/_utils.py @@ -107,10 +107,13 @@ def from_env(cls, app_path: str | None, app_dir: Path | None = None) -> Litestar dotenv.load_dotenv() app_path = app_path or getenv("LITESTAR_APP") + app_name = getenv("LITESTAR_APP_NAME") or "Litestar" + quiet_console = getenv("LITESTAR_QUIET_CONSOLE") or False if app_path and getenv("LITESTAR_APP") is None: os.environ["LITESTAR_APP"] = app_path if app_path: - console.print(f"Using Litestar app from env: [bright_blue]{app_path!r}") + if not quiet_console: + console.print(f"Using {app_name} from env: [bright_blue]{app_path!r}") loaded_app = _load_app_from_path(app_path) else: loaded_app = _autodiscover_app(cwd) diff --git a/litestar/cli/commands/core.py b/litestar/cli/commands/core.py index 5b552533b1..06e5d4175f 100644 --- a/litestar/cli/commands/core.py +++ b/litestar/cli/commands/core.py @@ -189,7 +189,7 @@ def run_command( if pdb: os.environ["LITESTAR_PDB"] = "1" - + quiet_console = os.getenv("LITESTAR_QUIET_CONSOLE") or False if not UVICORN_INSTALLED: console.print( r"uvicorn is not installed. Please install the standard group, litestar\[standard], to use this command." @@ -228,9 +228,9 @@ def run_command( else validate_ssl_file_paths(ssl_certfile, ssl_keyfile) ) - console.rule("[yellow]Starting server process", align="left") - - show_app_info(app) + if not quiet_console: + console.rule("[yellow]Starting server process", align="left") + show_app_info(app) with _server_lifespan(app): if workers == 1 and not reload: import uvicorn diff --git a/tests/unit/test_cli/test_core_commands.py b/tests/unit/test_cli/test_core_commands.py index e2acc93c67..4292e25b41 100644 --- a/tests/unit/test_cli/test_core_commands.py +++ b/tests/unit/test_cli/test_core_commands.py @@ -1,3 +1,4 @@ +import io import os import re import sys @@ -9,9 +10,10 @@ from _pytest.monkeypatch import MonkeyPatch from click.testing import CliRunner from pytest_mock import MockerFixture +from rich.console import Console from litestar import __version__ as litestar_version -from litestar.cli._utils import remove_default_schema_routes, remove_routes_with_patterns +from litestar.cli import _utils from litestar.cli.main import litestar_group as cli_command from litestar.exceptions import LitestarWarning @@ -278,6 +280,58 @@ def test_run_command_debug( assert os.getenv("LITESTAR_DEBUG") == "1" +@pytest.mark.usefixtures("mock_uvicorn_run", "unset_env") +def test_run_command_quiet_console( + app_file: Path, runner: CliRunner, monkeypatch: MonkeyPatch, create_app_file: CreateAppFileFixture +) -> None: + console = Console(file=io.StringIO()) + monkeypatch.setattr(_utils, "console", console) + + path = create_app_file("_create_app_with_path.py", content=CREATE_APP_FILE_CONTENT) + app_path = f"{path.stem}:create_app" + monkeypatch.delenv("LITESTAR_QUIET_CONSOLE", raising=False) + result = runner.invoke(cli_command, ["--app", app_path, "run"]) + assert result.exit_code == 0 + normal_output = console.file.getvalue() # type: ignore[attr-defined] + assert "Using Litestar from env:" in normal_output + assert "Starting server process" in result.stdout + del result + console = Console(file=io.StringIO()) + monkeypatch.setattr(_utils, "console", console) + monkeypatch.setenv("LITESTAR_QUIET_CONSOLE", "1") + assert os.getenv("LITESTAR_QUIET_CONSOLE") == "1" + result = runner.invoke(cli_command, ["--app", app_path, "run"]) + assert result.exit_code == 0 + quiet_output = console.file.getvalue() # type: ignore[attr-defined] + assert "Starting server process" not in result.stdout + assert "Using Litestar from env:" not in quiet_output + console.clear() + + +@pytest.mark.usefixtures("mock_uvicorn_run", "unset_env") +def test_run_command_custom_app_name( + app_file: Path, runner: CliRunner, monkeypatch: MonkeyPatch, create_app_file: CreateAppFileFixture +) -> None: + console = Console(file=io.StringIO()) + monkeypatch.setattr(_utils, "console", console) + + path = create_app_file("_create_app_with_path.py", content=CREATE_APP_FILE_CONTENT) + app_path = f"{path.stem}:create_app" + monkeypatch.delenv("LITESTAR_APP_NAME", raising=False) + result = runner.invoke(cli_command, ["--app", app_path, "run"]) + assert result.exit_code == 0 + _output = console.file.getvalue() # type: ignore[attr-defined] + assert "Using Litestar from env:" in _output + console = Console(file=io.StringIO()) + monkeypatch.setattr(_utils, "console", console) + monkeypatch.setenv("LITESTAR_APP_NAME", "My Stuff") + assert os.getenv("LITESTAR_APP_NAME") == "My Stuff" + result = runner.invoke(cli_command, ["--app", app_path, "run"]) + assert result.exit_code == 0 + _output = console.file.getvalue() # type: ignore[attr-defined] + assert "Using My Stuff from env:" in _output + + @pytest.mark.usefixtures("mock_uvicorn_run", "unset_env") def test_run_command_pdb( app_file: Path, @@ -425,7 +479,7 @@ def test_remove_default_schema_routes() -> None: api_config = MagicMock() api_config.openapi_controller.path = "/schema" - results = remove_default_schema_routes(http_routes, api_config) # type: ignore[arg-type] + results = _utils.remove_default_schema_routes(http_routes, api_config) # type: ignore[arg-type] assert len(results) == 3 for result in results: words = re.split(r"(^\/[a-z]+)", result.path) @@ -441,7 +495,7 @@ def test_remove_routes_with_patterns() -> None: http_routes.append(http_route) patterns = ("/destroy", "/pizza", "[]") - results = remove_routes_with_patterns(http_routes, patterns) # type: ignore[arg-type] + results = _utils.remove_routes_with_patterns(http_routes, patterns) # type: ignore[arg-type] paths = [route.path for route in results] assert len(paths) == 2 for route in ["/", "/foo"]: From dd8efe82914d3e079e02cac3f599b27e4f7d701d Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 14 Mar 2024 06:40:51 +0100 Subject: [PATCH 08/19] feat: add flash plugin (#3145) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --------- Co-authored-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> Co-authored-by: Jacob Coffee --- .../plugins/flash_messages/__init__.py | 0 docs/examples/plugins/flash_messages/jinja.py | 9 +++ docs/examples/plugins/flash_messages/mako.py | 9 +++ .../plugins/flash_messages/minijinja.py | 9 +++ docs/examples/plugins/flash_messages/usage.py | 26 ++++++ docs/index.rst | 2 +- docs/reference/plugins/flash_messages.rst | 7 ++ docs/reference/plugins/index.rst | 1 + docs/usage/index.rst | 2 +- docs/usage/plugins/flash_messages.rst | 73 +++++++++++++++++ docs/usage/{plugins.rst => plugins/index.rst} | 8 +- docs/usage/requests.rst | 2 +- litestar/contrib/minijinja.py | 10 ++- litestar/plugins/flash.py | 74 +++++++++++++++++ litestar/utils/scope/state.py | 3 + tests/unit/test_plugins/test_flash.py | 80 +++++++++++++++++++ 16 files changed, 310 insertions(+), 5 deletions(-) create mode 100644 docs/examples/plugins/flash_messages/__init__.py create mode 100644 docs/examples/plugins/flash_messages/jinja.py create mode 100644 docs/examples/plugins/flash_messages/mako.py create mode 100644 docs/examples/plugins/flash_messages/minijinja.py create mode 100644 docs/examples/plugins/flash_messages/usage.py create mode 100644 docs/reference/plugins/flash_messages.rst create mode 100644 docs/usage/plugins/flash_messages.rst rename docs/usage/{plugins.rst => plugins/index.rst} (97%) create mode 100644 litestar/plugins/flash.py create mode 100644 tests/unit/test_plugins/test_flash.py diff --git a/docs/examples/plugins/flash_messages/__init__.py b/docs/examples/plugins/flash_messages/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/plugins/flash_messages/jinja.py b/docs/examples/plugins/flash_messages/jinja.py new file mode 100644 index 0000000000..6772697642 --- /dev/null +++ b/docs/examples/plugins/flash_messages/jinja.py @@ -0,0 +1,9 @@ +from litestar import Litestar +from litestar.contrib.jinja import JinjaTemplateEngine +from litestar.plugins.flash import FlashConfig, FlashPlugin +from litestar.template.config import TemplateConfig + +template_config = TemplateConfig(engine=JinjaTemplateEngine, directory="templates") +flash_plugin = FlashPlugin(config=FlashConfig(template_config=template_config)) + +app = Litestar(plugins=[flash_plugin]) diff --git a/docs/examples/plugins/flash_messages/mako.py b/docs/examples/plugins/flash_messages/mako.py new file mode 100644 index 0000000000..a5ce038eab --- /dev/null +++ b/docs/examples/plugins/flash_messages/mako.py @@ -0,0 +1,9 @@ +from litestar import Litestar +from litestar.contrib.mako import MakoTemplateEngine +from litestar.plugins.flash import FlashConfig, FlashPlugin +from litestar.template.config import TemplateConfig + +template_config = TemplateConfig(engine=MakoTemplateEngine, directory="templates") +flash_plugin = FlashPlugin(config=FlashConfig(template_config=template_config)) + +app = Litestar(plugins=[flash_plugin]) diff --git a/docs/examples/plugins/flash_messages/minijinja.py b/docs/examples/plugins/flash_messages/minijinja.py new file mode 100644 index 0000000000..0ea2ce0f8e --- /dev/null +++ b/docs/examples/plugins/flash_messages/minijinja.py @@ -0,0 +1,9 @@ +from litestar import Litestar +from litestar.contrib.minijinja import MiniJinjaTemplateEngine +from litestar.plugins.flash import FlashConfig, FlashPlugin +from litestar.template.config import TemplateConfig + +template_config = TemplateConfig(engine=MiniJinjaTemplateEngine, directory="templates") +flash_plugin = FlashPlugin(config=FlashConfig(template_config=template_config)) + +app = Litestar(plugins=[flash_plugin]) diff --git a/docs/examples/plugins/flash_messages/usage.py b/docs/examples/plugins/flash_messages/usage.py new file mode 100644 index 0000000000..914919ea0f --- /dev/null +++ b/docs/examples/plugins/flash_messages/usage.py @@ -0,0 +1,26 @@ +from litestar import Litestar, Request, get +from litestar.contrib.jinja import JinjaTemplateEngine +from litestar.plugins.flash import FlashConfig, FlashPlugin, flash +from litestar.response import Template +from litestar.template.config import TemplateConfig + +template_config = TemplateConfig(engine=JinjaTemplateEngine, directory="templates") +flash_plugin = FlashPlugin(config=FlashConfig(template_config=template_config)) + + +@get() +async def index(request: Request) -> Template: + """Example of adding and displaying a flash message.""" + flash(request, "Oh no! I've been flashed!", category="error") + + return Template( + template_str=""" +

Flash Message Example

+ {% for message in get_flashes() %} +

{{ message.message }} (Category:{{ message.category }})

+ {% endfor %} + """ + ) + + +app = Litestar(plugins=[flash_plugin], route_handlers=[index], template_config=template_config) diff --git a/docs/index.rst b/docs/index.rst index d8c2822765..11b2150d66 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,7 +3,7 @@ Litestar library documentation Litestar is a powerful, flexible, highly performant, and opinionated ASGI framework. -The Litestar framework supports :doc:`/usage/plugins`, ships +The Litestar framework supports :doc:`/usage/plugins/index`, ships with :doc:`dependency injection `, :doc:`security primitives `, :doc:`OpenAPI schema generation `, `MessagePack `_, :doc:`middlewares `, a great :doc:`CLI ` experience, and much more. diff --git a/docs/reference/plugins/flash_messages.rst b/docs/reference/plugins/flash_messages.rst new file mode 100644 index 0000000000..34d1b411d5 --- /dev/null +++ b/docs/reference/plugins/flash_messages.rst @@ -0,0 +1,7 @@ +===== +flash +===== + + +.. automodule:: litestar.plugins.flash + :members: diff --git a/docs/reference/plugins/index.rst b/docs/reference/plugins/index.rst index 5de973df3c..128fdf0302 100644 --- a/docs/reference/plugins/index.rst +++ b/docs/reference/plugins/index.rst @@ -9,5 +9,6 @@ plugins :maxdepth: 1 :hidden: + flash_messages structlog sqlalchemy diff --git a/docs/usage/index.rst b/docs/usage/index.rst index 632375c475..b19abc4ca4 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -22,7 +22,7 @@ Usage metrics/index middleware/index openapi - plugins + plugins/index responses security/index static-files diff --git a/docs/usage/plugins/flash_messages.rst b/docs/usage/plugins/flash_messages.rst new file mode 100644 index 0000000000..8ff46b8db6 --- /dev/null +++ b/docs/usage/plugins/flash_messages.rst @@ -0,0 +1,73 @@ +============== +Flash Messages +============== + +.. versionadded:: 2.7.0 + +Flash messages are a powerful tool for conveying information to the user, +such as success notifications, warnings, or errors through one-time messages alongside a response due +to some kind of user action. + +They are typically used to display a message on the next page load and are a great way +to enhance user experience by providing immediate feedback on their actions from things like form submissions. + +Registering the plugin +---------------------- + +The FlashPlugin can be easily integrated with different templating engines. +Below are examples of how to register the ``FlashPlugin`` with ``Jinja2``, ``Mako``, and ``MiniJinja`` templating engines. + +.. tab-set:: + + .. tab-item:: Jinja2 + :sync: jinja + + .. literalinclude:: /examples/plugins/flash_messages/jinja.py + :language: python + :caption: Registering the flash message plugin using the Jinja2 templating engine + + .. tab-item:: Mako + :sync: mako + + .. literalinclude:: /examples/plugins/flash_messages/mako.py + :language: python + :caption: Registering the flash message plugin using the Mako templating engine + + .. tab-item:: MiniJinja + :sync: minijinja + + .. literalinclude:: /examples/plugins/flash_messages/minijinja.py + :language: python + :caption: Registering the flash message plugin using the MiniJinja templating engine + +Using the plugin +---------------- + +After registering the FlashPlugin with your application, you can start using it to add and display +flash messages within your application routes. + +Here is an example showing how to use the FlashPlugin with the Jinja2 templating engine to display flash messages. +The same approach applies to Mako and MiniJinja engines as well. + +.. literalinclude:: /examples/plugins/flash_messages/usage.py + :language: python + :caption: Using the flash message plugin with Jinja2 templating engine to display flash messages + +Breakdown ++++++++++ + +#. Here we import the requires classes and functions from the Litestar package and related plugins. +#. We then create our ``TemplateConfig`` and ``FlashConfig`` instances, each setting up the configuration for + the template engine and flash messages, respectively. +#. A single route handler named ``index`` is defined using the ``@get()`` decorator. + + * Within this handler, the ``flash`` function is called to add a new flash message. + This message is stored in the request's context, making it accessible to the template engine for rendering in the response. + * The function returns a ``Template`` instance, where ``template_str`` + (read more about :ref:`template strings `) + contains inline HTML and Jinja2 template code. + This template dynamically displays any flash messages by iterating over them with a Jinja2 for loop. + Each message is wrapped in a paragraph (``

``) tag, showing the message content and its category. + +#. Finally, a ``Litestar`` application instance is created, specifying the ``flash_plugin`` and ``index`` route handler in its configuration. + The application is also configured with the ``template_config``, which includes the ``Jinja2`` templating engine and the path to the templates directory. diff --git a/docs/usage/plugins.rst b/docs/usage/plugins/index.rst similarity index 97% rename from docs/usage/plugins.rst rename to docs/usage/plugins/index.rst index 4911b8da23..ff0cdc1652 100644 --- a/docs/usage/plugins.rst +++ b/docs/usage/plugins/index.rst @@ -1,3 +1,4 @@ +======= Plugins ======= @@ -84,7 +85,7 @@ Example The following example shows the actual implementation of the ``SerializationPluginProtocol`` for `SQLAlchemy `_ models that is is provided in ``advanced_alchemy``. -.. literalinclude:: ../../litestar/contrib/sqlalchemy/plugins/serialization.py +.. literalinclude:: ../../../litestar/contrib/sqlalchemy/plugins/serialization.py :language: python :caption: ``SerializationPluginProtocol`` implementation example @@ -123,3 +124,8 @@ signature (their :func:`__init__` method). .. literalinclude:: /examples/plugins/di_plugin.py :language: python :caption: Dynamically generating signature information for a custom type + +.. toctree:: + :titlesonly: + + flash_messages diff --git a/docs/usage/requests.rst b/docs/usage/requests.rst index fdc65f2784..485748ac8c 100644 --- a/docs/usage/requests.rst +++ b/docs/usage/requests.rst @@ -17,7 +17,7 @@ The type of ``data`` an be any supported type, including * :class:`TypedDicts ` * Pydantic models * Arbitrary stdlib types -* Typed supported via :doc:`plugins ` +* Typed supported via :doc:`plugins ` .. literalinclude:: /examples/request_data/request_data_2.py :language: python diff --git a/litestar/contrib/minijinja.py b/litestar/contrib/minijinja.py index 6007a18180..1fcd14bc10 100644 --- a/litestar/contrib/minijinja.py +++ b/litestar/contrib/minijinja.py @@ -159,7 +159,9 @@ def get_template(self, template_name: str) -> MiniJinjaTemplate: return MiniJinjaTemplate(self.engine, template_name) def register_template_callable( - self, key: str, template_callable: TemplateCallableType[StateProtocol, P, T] + self, + key: str, + template_callable: TemplateCallableType[StateProtocol, P, T], ) -> None: """Register a callable on the template engine. @@ -170,6 +172,12 @@ def register_template_callable( Returns: None """ + + def is_decorated(func: Callable) -> bool: + return hasattr(func, "__wrapped__") or func.__name__ not in globals() + + if not is_decorated(template_callable): + template_callable = _transform_state(template_callable) # type: ignore[arg-type] # pragma: no cover self.engine.add_global(key, pass_state(template_callable)) def render_string(self, template_string: str, context: Mapping[str, Any]) -> str: diff --git a/litestar/plugins/flash.py b/litestar/plugins/flash.py new file mode 100644 index 0000000000..6b61040120 --- /dev/null +++ b/litestar/plugins/flash.py @@ -0,0 +1,74 @@ +"""Plugin for creating and retrieving flash messages.""" +from dataclasses import dataclass +from typing import Any, Mapping + +from litestar.config.app import AppConfig +from litestar.connection import ASGIConnection +from litestar.contrib.minijinja import MiniJinjaTemplateEngine +from litestar.plugins import InitPluginProtocol +from litestar.template import TemplateConfig +from litestar.template.base import _get_request_from_context +from litestar.utils.scope.state import ScopeState + + +@dataclass +class FlashConfig: + """Configuration for Flash messages.""" + + template_config: TemplateConfig + + +class FlashPlugin(InitPluginProtocol): + """Flash messages Plugin.""" + + def __init__(self, config: FlashConfig): + """Initialize the plugin. + + Args: + config: Configuration for flash messages, including the template engine instance. + """ + self.config = config + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Register the message callable on the template engine instance. + + Args: + app_config: The application configuration. + + Returns: + The application configuration with the message callable registered. + """ + if isinstance(self.config.template_config.engine_instance, MiniJinjaTemplateEngine): + from litestar.contrib.minijinja import _transform_state + + self.config.template_config.engine_instance.register_template_callable( + "get_flashes", _transform_state(get_flashes) + ) + else: + self.config.template_config.engine_instance.register_template_callable("get_flashes", get_flashes) + return app_config + + +def flash(connection: ASGIConnection, message: str, category: str) -> None: + """Add a flash message to the request scope. + + Args: + connection: The connection instance. + message: The message to flash. + category: The category of the message. + """ + scope_state = ScopeState.from_scope(connection.scope) + scope_state.flash_messages.append({"message": message, "category": category}) + + +def get_flashes(context: Mapping[str, Any]) -> Any: + """Get flash messages from the request scope, if any. + + Args: + context: The context dictionary. + + Returns: + The flash messages, if any. + """ + scope_state = ScopeState.from_scope(_get_request_from_context(context).scope) + return scope_state.flash_messages diff --git a/litestar/utils/scope/state.py b/litestar/utils/scope/state.py index bed43940e2..2799915a19 100644 --- a/litestar/utils/scope/state.py +++ b/litestar/utils/scope/state.py @@ -33,6 +33,7 @@ class ScopeState: "csrf_token", "dependency_cache", "do_cache", + "flash_messages", "form", "headers", "is_cached", @@ -56,6 +57,7 @@ def __init__(self) -> None: self.dependency_cache = Empty self.do_cache = Empty self.form = Empty + self.flash_messages = [] self.headers = Empty self.is_cached = Empty self.json = Empty @@ -76,6 +78,7 @@ def __init__(self) -> None: dependency_cache: dict[str, Any] | EmptyType do_cache: bool | EmptyType form: dict[str, str | list[str]] | EmptyType + flash_messages: list[dict[str, str]] headers: Headers | EmptyType is_cached: bool | EmptyType json: Any | EmptyType diff --git a/tests/unit/test_plugins/test_flash.py b/tests/unit/test_plugins/test_flash.py new file mode 100644 index 0000000000..2a283b63fb --- /dev/null +++ b/tests/unit/test_plugins/test_flash.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from enum import Enum +from pathlib import Path + +import pytest + +from litestar import Request, get +from litestar.contrib.jinja import JinjaTemplateEngine +from litestar.contrib.mako import MakoTemplateEngine +from litestar.contrib.minijinja import MiniJinjaTemplateEngine +from litestar.plugins.flash import FlashConfig, FlashPlugin, flash +from litestar.response import Template +from litestar.template import TemplateConfig, TemplateEngineProtocol +from litestar.testing import create_test_client + +text_html_jinja = """{% for message in get_flashes() %}{{ message.message }}{% endfor %}""" +text_html_mako = """<% messages = get_flashes() %>\\ +% for m in messages: +${m['message']}\\ +% endfor +""" + + +class CustomCategory(str, Enum): + custom1 = "1" + custom2 = "2" + custom3 = "3" + + +class FlashCategory(str, Enum): + info = "INFO" + error = "ERROR" + warning = "WARNING" + success = "SUCCESS" + + +@pytest.mark.parametrize( + "engine, template_str", + ( + (JinjaTemplateEngine, text_html_jinja), + (MakoTemplateEngine, text_html_mako), + (MiniJinjaTemplateEngine, text_html_jinja), + ), + ids=("jinja", "mako", "minijinja"), +) +@pytest.mark.parametrize( + "category_enum", + (CustomCategory, FlashCategory), + ids=("custom_category", "flash_category"), +) +def test_flash_plugin( + tmp_path: Path, + engine: type[TemplateEngineProtocol], + template_str: str, + category_enum: Enum, +) -> None: + Path(tmp_path / "flash.html").write_text(template_str) + text_expected = "".join( + [f'message {category.value}' for category in category_enum] # type: ignore[attr-defined] + ) + + @get("/flash") + def flash_handler(request: Request) -> Template: + for category in category_enum: # type: ignore[attr-defined] + flash(request, f"message {category.value}", category=category.value) + return Template("flash.html") + + template_config: TemplateConfig = TemplateConfig( + directory=Path(tmp_path), + engine=engine, + ) + with create_test_client( + [flash_handler], + template_config=template_config, + plugins=[FlashPlugin(config=FlashConfig(template_config=template_config))], + ) as client: + r = client.get("/flash") + assert r.status_code == 200 + assert r.text == text_expected From eb52f471877ba7b273730e58235e85398935f625 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Sat, 16 Mar 2024 10:12:12 +0100 Subject: [PATCH 09/19] chore: Fix formatting for `develop` (#3214) Fix formatting --- litestar/plugins/flash.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litestar/plugins/flash.py b/litestar/plugins/flash.py index 6b61040120..52b5aa1685 100644 --- a/litestar/plugins/flash.py +++ b/litestar/plugins/flash.py @@ -1,4 +1,5 @@ """Plugin for creating and retrieving flash messages.""" + from dataclasses import dataclass from typing import Any, Mapping From aa8274ff37f2093e794328eabfddb542a07644de Mon Sep 17 00:00:00 2001 From: kedod <35638715+kedod@users.noreply.github.com> Date: Sat, 16 Mar 2024 10:23:44 +0100 Subject: [PATCH 10/19] feat: Use memoized `request_class` and `response_class` values (#3205) --- litestar/handlers/http_handlers/base.py | 27 +++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/litestar/handlers/http_handlers/base.py b/litestar/handlers/http_handlers/base.py index 757253e675..2d5d8322ff 100644 --- a/litestar/handlers/http_handlers/base.py +++ b/litestar/handlers/http_handlers/base.py @@ -78,6 +78,8 @@ class HTTPRouteHandler(BaseRouteHandler): "_resolved_before_request", "_response_handler_mapping", "_resolved_include_in_schema", + "_resolved_response_class", + "_resolved_request_class", "_resolved_tags", "_resolved_security", "after_request", @@ -290,6 +292,8 @@ def __init__( self._resolved_before_request: AsyncAnyCallable | None | EmptyType = Empty self._response_handler_mapping: ResponseHandlerMap = {"default_handler": Empty, "response_type_handler": Empty} self._resolved_include_in_schema: bool | EmptyType = Empty + self._resolved_response_class: type[Response] | EmptyType = Empty + self._resolved_request_class: type[Request] | EmptyType = Empty self._resolved_security: list[SecurityRequirement] | EmptyType = Empty self._resolved_tags: list[str] | EmptyType = Empty @@ -312,10 +316,14 @@ def resolve_request_class(self) -> type[Request]: Returns: The default :class:`Request <.connection.Request>` class for the route handler. """ - return next( - (layer.request_class for layer in reversed(self.ownership_layers) if layer.request_class is not None), - Request, - ) + + if self._resolved_request_class is Empty: + self._resolved_request_class = next( + (layer.request_class for layer in reversed(self.ownership_layers) if layer.request_class is not None), + Request, + ) + + return cast("type[Request]", self._resolved_request_class) def resolve_response_class(self) -> type[Response]: """Return the closest custom Response class in the owner graph or the default Response class. @@ -325,10 +333,13 @@ def resolve_response_class(self) -> type[Response]: Returns: The default :class:`Response <.response.Response>` class for the route handler. """ - return next( - (layer.response_class for layer in reversed(self.ownership_layers) if layer.response_class is not None), - Response, - ) + if self._resolved_response_class is Empty: + self._resolved_response_class = next( + (layer.response_class for layer in reversed(self.ownership_layers) if layer.response_class is not None), + Response, + ) + + return cast("type[Response]", self._resolved_response_class) def resolve_response_headers(self) -> frozenset[ResponseHeader]: """Return all header parameters in the scope of the handler function. From ff0be36c7f2706439eda5e74ac558d71a1666ff6 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Sun, 17 Mar 2024 12:39:46 +1000 Subject: [PATCH 11/19] refactor: openapi router (#2893) * refactor: openapi router This PR refactors the way that we support multiple UIs for OpenAPI. We add `litestar.openapi.plugins` where `OpenAPIRenderPlugin` is defined, and implementations of that plugin for the frameworks we currently support. We add `OpenAPIConfig.render_plugins` config option, where a user can explicitly declare a set of plugins for UIs they wish to support. If a user declares a sub-class of `OpenAPIController` at `OpenAPIConfig.openapi_controller`, then existing behavior is preserved exactly. However, if no controller is explicitly declared, we invoke the new router-based approach, which should behave identically to the controller based approach (i.e., respect `enabled_endpoints` and `root_schema_site`). Closes #2541 * docs: start of documentation re-write. - creates an indexed directory for openapi - removes the controller docs - start of docs for plugins * refactor: move JsonRenderPlugin into private ns We add the json plugin, and have hardcoded refs to the path that it serves, so best not to make this public API just yet. * docs: reference docs for plugins * Revert "refactor: move JsonRenderPlugin into private ns" This reverts commit 60719aabd31255ff315eb6dda768478dee371b8f. * docs: JsonRenderPlugin undocumented. * docs: continue plugin docs * test: run tests for both plugin and controller Modifies tests where appropriate to run on both the plugin-based approach and the controller based approach. * Implement default endpoint selection logic. * Deprecation of OpenAPIController configs * docs: swagger oauth examples * Update docs/usage/openapi/ui_plugins.rst * Update docs/usage/openapi/ui_plugins.rst * fix: linting * refactor: don't rely on DI for openapi schema in plugin handler. * fix(test): there's an extra schema route to serve 404s. * fix(docs): docstring indent * fix(lint): remove redundant return * refactor: plugins receive style tag instead of tag content. * feat: allow openapi router to be handed to openapi config. Allows for customization, such as adding guards, middleware, other routes, etc. * feat: add `scalar` schema ui (#2906) * Update litestar/openapi/plugins.py Co-authored-by: Jacob Coffee * Update litestar/openapi/plugins.py Co-authored-by: Jacob Coffee * Update litestar/openapi/plugins.py Co-authored-by: Jacob Coffee * Update litestar/openapi/config.py Co-authored-by: Jacob Coffee * Update litestar/openapi/config.py Co-authored-by: Jacob Coffee * fix: update deprecation version * fix: use GH repo for scalar links * fix: update default scalar version * fix: scalar plugin style attribute render. Plugins expect that the style value is already wrapped in `.`` + +Most plugins support the following additional options: + +- ``version``: The version of the UIs JS and (in some cases) CSS bundle to use. We use the ``version`` to construct the + URL to retrieve the the bundle from ``unpkg``, e.g., ``https://unpkg.com/rapidoc@/dist/rapidoc-min.js`` +- ``js_url``: The URL to the JS bundle. If provided, this will override the ``version`` option. +- ``css_url``: The URL to the CSS bundle. If provided, this will override the ``version`` option. + +Here's some example plugin configurations: + +.. tab-set:: + + .. tab-item:: scalar + :sync: scalar + + .. literalinclude:: /examples/openapi/plugins/scalar_config.py + :language: python + + .. tab-item:: rapidoc + :sync: rapidoc + + .. literalinclude:: /examples/openapi/plugins/rapidoc_config.py + :language: python + + .. tab-item:: redoc + :sync: redoc + + .. literalinclude:: /examples/openapi/plugins/redoc_config.py + :language: python + + .. tab-item:: stoplight + :sync: stoplight + + .. literalinclude:: /examples/openapi/plugins/stoplight_config.py + :language: python + + .. tab-item:: swagger + :sync: swagger + + .. literalinclude:: /examples/openapi/plugins/swagger_ui_config.py + :language: python + +Configuring the OpenAPI Root Path +--------------------------------- + +The OpenAPI root path is the path at which the OpenAPI representations are served. By default, this is ``/schema``. +This can be changed by setting the :attr:`OpenAPIConfig.path` attribute. + +In the following example, we configure the OpenAPI root path to be ``/docs``: + +.. literalinclude:: /examples/openapi/customize_path.py + :language: python + +This will result in any of the OpenAPI endpoints being served at ``/docs`` instead of ``/schema``, e.g., +``/docs/openapi.json``. + +Backward Compatibility +---------------------- + +OpenAPI UI plugins are a new feature introduced in ``v2.8.0``. + +Providing a subclass of OpenAPIController +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. deprecated:: v2.8.0 + +The previous method of configuring elements such as the root path and styling was to subclass +:class:`OpenAPIController`, and set it on the :attr:`OpenAPIConfig.openapi_controller` attribute. This approach is now +deprecated and slated for removal in ``v3.0.0``, but if you are using it, there should be no change in behavior. + +To maintain backward compatibility with the previous approach, if neither the :attr:`OpenAPIConfig.openapi_controller` +or :attr:`OpenAPIConfig.render_plugins` attributes are set, we will automatically add the plugins to respect the also +deprecated :attr:`OpenAPIConfig.enabled_endpoints` attribute. By default, this will result in the following endpoints +being enabled: + +- ``/schema/openapi.json`` +- ``/schema/redoc`` +- ``/schema/rapidoc`` +- ``/schema/elements`` +- ``/schema/swagger`` +- ``/schema/openapi.yml`` +- ``/schema/openapi.yaml`` + +In ``v3.0.0``, the :attr:`OpenAPIConfig.enabled_endpoints` attribute will be removed, and only a single UI plugin will be +enabled by default, in addition to the ``openapi.json`` endpoint which will always be enabled. ``Scalar`` will also +become the default UI plugin in ``v3.0.0``. + +To adopt the future behavior, explicitly set the :attr:`OpenAPIConfig.render_plugins` field to an instance of +:class:`ScalarRenderPlugin`: + +.. literalinclude:: /examples/openapi/plugins/scalar_simple.py + :language: python + :lines: 13-21 + +Backward compatibility with ``root_schema_site`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Litestar has always supported a ``root_schema_site`` attribute on the :class:`OpenAPIConfig` class. This attribute +allows you to elect to serve a UI at the OpenAPI root path, e.g., by default ``redoc`` would be served at both +``/schema`` and ``/schema/redoc``. + +In ``v3.0.0``, the ``root_schema_site`` attribute will be removed, and the first :class:`OpenAPIRenderPlugin` in the +:attr:`OpenAPIConfig.render_plugins` list will be assigned to the ``/schema`` endpoint. + +As of ``v2.8.0``, if you explicitly use the new :attr:`OpenAPIConfig.render_plugins` attribute, you will be +automatically opted in to the new behavior, and the ``root_schema_site`` attribute will be ignored. + +Building your own OpenAPI UI Plugin +----------------------------------- + +If Litestar does not have built-in support for your OpenAPI UI framework of choice, you can easily create your own +plugin by subclassing :class:`OpenAPIRenderPlugin` and implementing the :meth:`OpenAPIRenderPlugin.render` method. + +To demonstrate building a custom plugin, we'll look at a plugin very similar to the :class:`ScalarRenderPlugin` that is +maintained by Litestar. Here's the finished product: + +.. literalinclude:: /examples/openapi/plugins/custom_plugin.py + :language: python + +Class definition +~~~~~~~~~~~~~~~~ + +The class ``ScalarRenderPlugin`` inherits from :class:`OpenAPIRenderPlugin`: + +.. literalinclude:: /examples/openapi/plugins/custom_plugin.py + :language: python + :lines: 10 + +``__init__`` Constructor +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. literalinclude:: /examples/openapi/plugins/custom_plugin.py + :language: python + :lines: 11-22 + +We support configuration via the following arguments: + +- ``version``: Specifies the version of RapiDoc to use. +- ``js_url``: Custom URL to the RapiDoc JavaScript bundle. +- ``css_url``: Custom URL to the RapiDoc CSS bundle. +- ``path``: The URL path where the RapiDoc UI will be served. +- ``**kwargs``: Captures additional arguments to pass to the superclass. + +And we construct a url for the Scalar JavaScript bundle if one is not provided: + +.. literalinclude:: /examples/openapi/plugins/custom_plugin.py + :language: python + :lines: 20 + +``render()`` +~~~~~~~~~~~~ + +.. literalinclude:: /examples/openapi/plugins/custom_plugin.py + :language: python + :lines: 24 + +Finally we define the ``render`` method, which is called by Litestar to render the UI. It receives the a +:class:`Request` object and the ``openapi_schema`` as a dictionary. + +Inside the ``render`` method, we construct the HTML to render the UI, and return it as a string. + +- ``head``: Defines the HTML ```` section, including the title from ``openapi_schema``, any additional styles + (``self.style``), the favicon and custom style sheet if one is provided: + + .. literalinclude:: /examples/openapi/plugins/custom_plugin.py + :language: python + :lines: 25-35 + +- ``body``: Constructs the HTML ````, including a link to the OpenAPI JSON, and the JavaScript bundle: + + .. literalinclude:: /examples/openapi/plugins/custom_plugin.py + :language: python + :lines: 37-43 + +- Finally, returns a complete HTML document (as a byte string), combining head and body. + + .. literalinclude:: /examples/openapi/plugins/custom_plugin.py + :language: python + :lines: 45-51 + +Interacting with the ``Router`` +------------------------------- + +An instance of :class:`Router` is used to serve the OpenAPI endpoints and is made available to plugins via the +:meth:`OpenAPIRenderPlugin.receive_router` method. + +This can be used for a variety of purposes, including adding additional routes to the ``Router``. + +.. literalinclude:: /examples/openapi/plugins/receive_router.py + :language: python + +OAuth2 in Swagger UI +-------------------- + +When using Swagger, OAuth2 settings can be configured via +:attr:`swagger_ui_init_oauth `, which can be set to +a dictionary containing the parameters described in the Swagger UI documentation for OAuth2 +`here `_. + +With that, you can preset your clientId or enable PKCE support. + +.. literalinclude:: /examples/openapi/plugins/swagger_ui_oauth.py + :language: python + +Customizing the OpenAPI UI +-------------------------- + +Style and behavior of the OpenAPI UI can be customized by overriding the default ``css_url`` and ``js_url`` attributes +on the render plugin class, for example: + +.. literalinclude:: /examples/openapi/plugins/scalar_customized.py + :language: python + +To learn more about customizing the ``Scalar`` UI, see the `Scalar documentation `_. + +CDN and offline file support +---------------------------- + +Each plugin supports ``js_url`` and ``css_url`` attributes, which can be used to specify a custom URL to the JavaScript. +These can be used to serve the JavaScript and CSS from a CDN, or to serve the files from a local directory. diff --git a/litestar/_openapi/plugin.py b/litestar/_openapi/plugin.py index 9bdbdecebd..347a5e8c5e 100644 --- a/litestar/_openapi/plugin.py +++ b/litestar/_openapi/plugin.py @@ -1,27 +1,62 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from litestar._openapi.datastructures import OpenAPIContext from litestar._openapi.path_item import create_path_item_for_route -from litestar.exceptions import ImproperlyConfiguredException +from litestar.constants import OPENAPI_JSON_HANDLER_NAME +from litestar.enums import MediaType +from litestar.exceptions import ImproperlyConfiguredException, NotFoundException +from litestar.handlers import get +from litestar.openapi.plugins import JsonRenderPlugin from litestar.plugins import InitPluginProtocol from litestar.plugins.base import ReceiveRoutePlugin +from litestar.response import Response +from litestar.router import Router from litestar.routes import HTTPRoute +from litestar.status_codes import HTTP_404_NOT_FOUND if TYPE_CHECKING: from litestar.app import Litestar from litestar.config.app import AppConfig + from litestar.connection import Request + from litestar.handlers import HTTPRouteHandler from litestar.openapi.config import OpenAPIConfig + from litestar.openapi.plugins import OpenAPIRenderPlugin from litestar.openapi.spec import OpenAPI from litestar.routes import BaseRoute +def handle_schema_path_not_found(path: str = "/") -> Response: + """Handler for returning HTML formatted errors from not-found schema paths. + + This preserves backward compatibility with the Controller-based OpenAPI implementation. + """ + if path.endswith((".json", ".yaml", ".yml")): + raise NotFoundException + + content = b""" + + + + 404 Not found + + + + +

Error 404

+ + + """ + return Response(content, media_type=MediaType.HTML, status_code=HTTP_404_NOT_FOUND) + + class OpenAPIPlugin(InitPluginProtocol, ReceiveRoutePlugin): __slots__ = ( "app", "included_routes", "_openapi_config", + "_openapi", "_openapi_schema", ) @@ -29,9 +64,10 @@ def __init__(self, app: Litestar) -> None: self.app = app self.included_routes: dict[str, HTTPRoute] = {} self._openapi_config: OpenAPIConfig | None = None - self._openapi_schema: OpenAPI | None = None + self._openapi: OpenAPI | None = None + self._openapi_schema: dict[str, object] | None = None - def _build_openapi_schema(self) -> OpenAPI: + def _build_openapi(self) -> OpenAPI: openapi_config = self.openapi_config if openapi_config.create_examples: @@ -49,14 +85,105 @@ def _build_openapi_schema(self) -> OpenAPI: return openapi def provide_openapi(self) -> OpenAPI: + if not self._openapi: + self._openapi = self._build_openapi() + return self._openapi + + def provide_openapi_schema(self) -> dict[str, Any]: if not self._openapi_schema: - self._openapi_schema = self._build_openapi_schema() + self._openapi_schema = self.provide_openapi().to_schema() return self._openapi_schema + def create_openapi_router(self) -> Router: + """Create a router for serving OpenAPI documentation and schema files. + + For each OpenAPI render plugin, a route is created to serve the plugin's + documentation site. + + A handler is added for serving a 404 page for any schema path that is not + configured by a plugin. + + A handler is added for serving the JSON OpenAPI schema file if it is not configured. + + For each plugin, the plugin's `receive_router` method is called with the router + instance. + + Returns: + The router. + """ + if (router := self.openapi_config.openapi_router) is None: + router = Router( + self.openapi_config.path or "/schema", + route_handlers=[], + include_in_schema=False, + dto=None, + return_dto=None, + ) + + root_configured = False + openapi_json_found = False + + def create_handler(plugin_: OpenAPIRenderPlugin) -> HTTPRouteHandler: + """Create a handler for serving the plugin's documentation site. + + If the plugin is the default plugin, a handler is created for the root path in addition + to the plugin's configured paths. + + If the plugin has a path for serving the OpenAPI schema file, the `openapi_json_found` + flag is set to `True`, so that we don't create a handler for serving the JSON schema file. + + Args: + plugin_: The plugin to create the handler for. + + Returns: + The handler. + """ + paths = list(plugin_.paths) + if plugin_ is self.openapi_config.default_plugin: + if not plugin_.has_path("/"): + paths.append("/") + nonlocal root_configured + root_configured = True + + handler_name = None + if plugin_.has_path("/openapi.json"): + nonlocal openapi_json_found + openapi_json_found = True + handler_name = OPENAPI_JSON_HANDLER_NAME + + @get(paths, media_type=plugin_.media_type, sync_to_thread=False, name=handler_name) + def _handler(request: Request) -> bytes: + return plugin_.render(request, self.provide_openapi_schema()) + + return _handler + + for plugin in self.openapi_config.render_plugins: + router.register(create_handler(plugin)) + + not_found_handler_paths = ["/{path:str}"] + if not root_configured: + not_found_handler_paths.append("/") + + not_found_handler = get(not_found_handler_paths, media_type=MediaType.HTML, sync_to_thread=False)( + handle_schema_path_not_found + ) + router.register(not_found_handler) + + if not openapi_json_found: + router.register(create_handler(JsonRenderPlugin())) + + for plugin in self.openapi_config.render_plugins: + plugin.receive_router(router) + + return router + def on_app_init(self, app_config: AppConfig) -> AppConfig: if app_config.openapi_config: self._openapi_config = app_config.openapi_config - app_config.route_handlers.append(self.openapi_config.openapi_controller) + if (controller := app_config.openapi_config.openapi_controller) is not None: + app_config.route_handlers.append(controller) + else: + app_config.route_handlers.append(self.create_openapi_router()) return app_config @property @@ -71,5 +198,5 @@ def receive_route(self, route: BaseRoute) -> None: if any(route_handler.resolve_include_in_schema() for route_handler, _ in route.route_handler_map.values()): # Force recompute the schema if a new route is added - self._openapi_schema = None + self._openapi = None self.included_routes[route.path] = route diff --git a/litestar/app.py b/litestar/app.py index 928617e322..995e90d6bd 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -860,7 +860,7 @@ def update_openapi_schema(self) -> None: Returns: None """ - self.plugins.get(OpenAPIPlugin)._build_openapi_schema() + self.plugins.get(OpenAPIPlugin)._build_openapi() def emit(self, event_id: str, *args: Any, **kwargs: Any) -> None: """Emit an event to all attached listeners. diff --git a/litestar/cli/_utils.py b/litestar/cli/_utils.py index 84ebd3b1af..a7bc2b9405 100644 --- a/litestar/cli/_utils.py +++ b/litestar/cli/_utils.py @@ -392,7 +392,12 @@ def show_app_info(app: Litestar) -> None: # pragma: no cover openapi_enabled = _format_is_enabled(app.openapi_config) if app.openapi_config: - openapi_enabled += f" path=[yellow]{app.openapi_config.openapi_controller.path}" + path = ( + app.openapi_config.openapi_controller.path + if app.openapi_config.openapi_controller + else app.openapi_config.path or "/schema" + ) + openapi_enabled += f" path=[yellow]{path}" table.add_row("OpenAPI", openapi_enabled) table.add_row("Compression", app.compression_config.backend if app.compression_config else "[red]Disabled") @@ -561,5 +566,9 @@ def remove_routes_with_patterns( def remove_default_schema_routes( routes: list[HTTPRoute | ASGIRoute | WebSocketRoute], openapi_config: OpenAPIConfig ) -> list[HTTPRoute | ASGIRoute | WebSocketRoute]: - schema_path = openapi_config.openapi_controller.path + schema_path = ( + (openapi_config.path or "/schema") + if openapi_config.openapi_controller is None + else openapi_config.openapi_controller.path + ) return remove_routes_with_patterns(routes, (schema_path,)) diff --git a/litestar/constants.py b/litestar/constants.py index 930296c4be..2ba12f3694 100644 --- a/litestar/constants.py +++ b/litestar/constants.py @@ -1,6 +1,9 @@ +from __future__ import annotations + from dataclasses import MISSING from inspect import Signature from typing import Any, Final +from uuid import uuid4 from msgspec import UnsetType @@ -14,6 +17,7 @@ HTTP_RESPONSE_BODY: Final = "http.response.body" HTTP_RESPONSE_START: Final = "http.response.start" ONE_MEGABYTE: Final = 1024 * 1024 +OPENAPI_JSON_HANDLER_NAME: Final = f"{uuid4().hex}_litestar_openapi_json" OPENAPI_NOT_INITIALIZED: Final = "Litestar has not been instantiated with OpenAPIConfig" REDIRECT_STATUS_CODES: Final = {301, 302, 303, 307, 308} REDIRECT_ALLOWED_MEDIA_TYPES: Final = {MediaType.TEXT, MediaType.HTML, MediaType.JSON} @@ -23,7 +27,6 @@ WEBSOCKET_CLOSE: Final = "websocket.close" WEBSOCKET_DISCONNECT: Final = "websocket.disconnect" - # deprecated constants _SCOPE_STATE_CSRF_TOKEN_KEY = "csrf_token" # noqa: S105 # possible hardcoded password _SCOPE_STATE_DEPENDENCY_CACHE: Final = "dependency_cache" diff --git a/litestar/openapi/config.py b/litestar/openapi/config.py index c935693696..f1fd0acf12 100644 --- a/litestar/openapi/config.py +++ b/litestar/openapi/config.py @@ -2,10 +2,17 @@ from copy import deepcopy from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Final, Literal, Sequence from litestar._openapi.utils import default_operation_id_creator -from litestar.openapi.controller import OpenAPIController +from litestar.openapi.plugins import ( + JsonRenderPlugin, + RapidocRenderPlugin, + RedocRenderPlugin, + StoplightRenderPlugin, + SwaggerRenderPlugin, + YamlRenderPlugin, +) from litestar.openapi.spec import ( Components, Contact, @@ -19,13 +26,29 @@ Server, Tag, ) +from litestar.utils.deprecation import warn_deprecation from litestar.utils.path import normalize_path +if TYPE_CHECKING: + from litestar.openapi.controller import OpenAPIController + from litestar.openapi.plugins import OpenAPIRenderPlugin + from litestar.router import Router + from litestar.types.callable_types import OperationIDCreator + __all__ = ("OpenAPIConfig",) +_enabled_plugin_map = { + "elements": StoplightRenderPlugin, + "openapi.json": JsonRenderPlugin, + "openapi.yaml": YamlRenderPlugin, + "openapi.yml": YamlRenderPlugin, + "rapidoc": RapidocRenderPlugin, + "redoc": RedocRenderPlugin, + "swagger": SwaggerRenderPlugin, + "oauth2-redirect.html": None, +} -if TYPE_CHECKING: - from litestar.types.callable_types import OperationIDCreator +_DEFAULT_SCHEMA_SITE: Final = "redoc" @dataclass @@ -45,11 +68,6 @@ class OpenAPIConfig: """Generate examples using the polyfactory library.""" random_seed: int = 10 """The random seed used when creating the examples to ensure deterministic generation of examples.""" - openapi_controller: type[OpenAPIController] = field(default_factory=lambda: OpenAPIController) - """Controller for generating OpenAPI routes. - - Must be subclass of :class:`OpenAPIController `. - """ contact: Contact | None = field(default=None) """API contact information, should be an :class:`Contact ` instance.""" description: str | None = field(default=None) @@ -90,31 +108,123 @@ class OpenAPIConfig: :class:`Reference ` objects. """ - root_schema_site: Literal["redoc", "swagger", "elements", "rapidoc"] = "redoc" - """The static schema generator to use for the "root" path of `/schema/`.""" - enabled_endpoints: set[str] = field( - default_factory=lambda: { - "redoc", - "swagger", - "elements", - "rapidoc", - "openapi.json", - "openapi.yaml", - "openapi.yml", - "oauth2-redirect.html", - } - ) - """A set of the enabled documentation sites and schema download endpoints.""" operation_id_creator: OperationIDCreator = default_operation_id_creator """A callable that generates unique operation ids""" path: str | None = field(default=None) - """Base path for the OpenAPI documentation endpoints.""" + """Base path for the OpenAPI documentation endpoints. + + If no path is provided the default is ``/schema``. + + Ignored if ``openapi_router`` is provided. + """ + render_plugins: Sequence[OpenAPIRenderPlugin] = field(default=()) + """Plugins for rendering OpenAPI documentation UIs.""" + openapi_router: Router | None = None + """An optional router for serving OpenAPI documentation and schema files. + + If provided, ``path`` is ignored. + + This parameter is also ignored if the deprecated :class:`OpenAPIConfig <.openapi.OpenAPIConfig>` ``openapi_controller`` kwarg is provided. + + The ``openapi_router`` is not required, but it can be passed to customize the configuration of the router used to serve the documentation endpoints. For example, you can add middleware or guards to the router. + + Handlers to serve the OpenAPI schema and documentation sites are added to this router according + to the ``render_plugins`` attribute, so routes shouldn't be added that conflict with these. + """ + openapi_controller: type[OpenAPIController] | None = None + """Controller for generating OpenAPI routes. + + Must be subclass of :class:`OpenAPIController `. + """ + root_schema_site: Literal["redoc", "swagger", "elements", "rapidoc"] | None = None + """The static schema generator to use for the "root" path of ``/schema/``.""" + enabled_endpoints: set[str] | None = None + """A set of the enabled documentation sites and schema download endpoints.""" def __post_init__(self) -> None: + self._issue_deprecations() + + self.root_schema_site = self.root_schema_site or _DEFAULT_SCHEMA_SITE + + self.enabled_endpoints = ( + set(_enabled_plugin_map.keys()) if self.enabled_endpoints is None else self.enabled_endpoints + ) + if self.path: self.path = normalize_path(self.path) + + if self.path and self.openapi_controller is not None: self.openapi_controller = type("OpenAPIController", (self.openapi_controller,), {"path": self.path}) + self.default_plugin: OpenAPIRenderPlugin | None = None + if self.openapi_controller is None: + if not self.render_plugins: + self._plugin_backward_compatibility() + else: + # user is implicitly opted into the future plugin-based OpenAPI implementation + # behavior by explicitly providing a list of render plugins + for plugin in self.render_plugins: + if plugin.has_path("/"): + self.default_plugin = plugin + break + else: + self.default_plugin = self.render_plugins[0] + + def _issue_deprecations(self) -> None: + """Handle deprecated config options.""" + deprecated_in = "v2.8.0" + removed_in = "v3.0.0" + if self.openapi_controller is not None: + warn_deprecation( + deprecated_in, + "openapi_controller", + "attribute", + removal_in=removed_in, + alternative="render_plugins", + ) + + if self.root_schema_site is not None: + warn_deprecation( + deprecated_in, + "root_schema_site", + "attribute", + removal_in=removed_in, + alternative="render_plugins", + info="Any 'render_plugin' with path '/' or first 'render_plugin' in list will be served at the OpenAPI root.", + ) + + if self.enabled_endpoints is not None: + warn_deprecation( + deprecated_in, + "enabled_endpoints", + "attribute", + removal_in=removed_in, + alternative="render_plugins", + info="Configure a 'render_plugin' to enable an endpoint.", + ) + + def _plugin_backward_compatibility(self) -> None: + """Backward compatibility for the plugin-based OpenAPI implementation. + + This preserves backward compatibility with the Controller-based OpenAPI implementation. + + We add a plugin for each enabled endpoint and set the default plugin to the plugin + that has a path ending in the value of ``root_schema_site``. + """ + + def is_default_plugin(plugin_: OpenAPIRenderPlugin) -> bool: + """Return True if the plugin is the default plugin.""" + root_schema_site = self.root_schema_site or _DEFAULT_SCHEMA_SITE + return any(path.endswith(root_schema_site) for path in plugin_.paths) + + self.render_plugins = rps = [] + for key in self.enabled_endpoints or (): + if plugin_type := _enabled_plugin_map[key]: + plugin = plugin_type() + rps.append(plugin) + if is_default_plugin(plugin): + self.default_plugin = plugin + def to_openapi_schema(self) -> OpenAPI: """Return an ``OpenAPI`` instance from the values stored in ``self``. diff --git a/litestar/openapi/controller.py b/litestar/openapi/controller.py index ac03d4cd15..ec1c00845b 100644 --- a/litestar/openapi/controller.py +++ b/litestar/openapi/controller.py @@ -1,15 +1,16 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Final, Literal +from typing import TYPE_CHECKING, Any, Callable, Literal from yaml import dump as dump_yaml -from litestar.constants import OPENAPI_NOT_INITIALIZED +from litestar.constants import OPENAPI_JSON_HANDLER_NAME, OPENAPI_NOT_INITIALIZED from litestar.controller import Controller from litestar.enums import MediaType, OpenAPIMediaType from litestar.exceptions import ImproperlyConfiguredException from litestar.handlers import get +from litestar.openapi.config import _DEFAULT_SCHEMA_SITE from litestar.response.base import ASGIResponse from litestar.serialization import encode_json from litestar.serialization.msgspec_hooks import decode_json @@ -22,8 +23,6 @@ from litestar.connection.request import Request from litestar.openapi.spec.open_api import OpenAPI -_OPENAPI_JSON_ROUTER_NAME: Final = "__litestar_openapi_json" - class OpenAPIController(Controller): """Controller for OpenAPI endpoints.""" @@ -118,11 +117,13 @@ def should_serve_endpoint(self, request: Request[Any, Any, Any]) -> bool: root_path = set(filter(None, self.path.split("/"))) config = request.app.openapi_config + enabled_endpoints = config.enabled_endpoints or set() + root_schema_site = config.root_schema_site or _DEFAULT_SCHEMA_SITE - if request_path == root_path and config.root_schema_site in config.enabled_endpoints: + if request_path == root_path and root_schema_site in enabled_endpoints: return True - return bool(request_path & config.enabled_endpoints) + return bool(request_path & enabled_endpoints) @property def favicon(self) -> str: @@ -178,7 +179,7 @@ def retrieve_schema_yaml(self, request: Request[Any, Any, Any]) -> ASGIResponse: media_type=OpenAPIMediaType.OPENAPI_JSON, include_in_schema=False, sync_to_thread=False, - name=_OPENAPI_JSON_ROUTER_NAME, + name=OPENAPI_JSON_HANDLER_NAME, ) def retrieve_schema_json(self, request: Request[Any, Any, Any]) -> ASGIResponse: """Return the OpenAPI schema as JSON with an ``application/vnd.oai.openapi+json`` Content-Type header. @@ -218,7 +219,7 @@ def root(self, request: Request[Any, Any, Any]) -> ASGIResponse: if not config: # pragma: no cover raise ImproperlyConfiguredException(OPENAPI_NOT_INITIALIZED) - render_method = self.render_methods_map[config.root_schema_site] + render_method = self.render_methods_map[config.root_schema_site or _DEFAULT_SCHEMA_SITE] if self.should_serve_endpoint(request): return ASGIResponse(body=render_method(request), media_type=MediaType.HTML) @@ -468,7 +469,7 @@ def render_stoplight_elements(self, request: Request[Any, Any, Any]) -> bytes: body = f""" @@ -499,7 +500,7 @@ def render_rapidoc(self, request: Request[Any, Any, Any]) -> bytes: # pragma: n body = f""" - + """ diff --git a/litestar/openapi/plugins.py b/litestar/openapi/plugins.py new file mode 100644 index 0000000000..a1d6c910d0 --- /dev/null +++ b/litestar/openapi/plugins.py @@ -0,0 +1,665 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Sequence + +import msgspec +import yaml + +from litestar.constants import OPENAPI_JSON_HANDLER_NAME +from litestar.enums import MediaType, OpenAPIMediaType +from litestar.handlers import get +from litestar.serialization import encode_json, get_serializer + +if TYPE_CHECKING: + from litestar.connection import Request + from litestar.router import Router + +__all__ = ( + "OpenAPIRenderPlugin", + "RapidocRenderPlugin", + "RedocRenderPlugin", + "ScalarRenderPlugin", + "StoplightRenderPlugin", + "SwaggerRenderPlugin", + "YamlRenderPlugin", +) + +_favicon_url = "https://cdn.jsdelivr.net/gh/litestar-org/branding@main/assets/Branding%20-%20PNG%20-%20Transparent/Badge%20-%20Blue%20and%20Yellow.png" +_default_favicon = f"" +_default_style = "" + + +class OpenAPIRenderPlugin(ABC): + """Base class for OpenAPI UI render plugins.""" + + paths: list[str] + + def __init__( + self, + *, + path: str | Sequence[str], + media_type: MediaType | OpenAPIMediaType = MediaType.HTML, + favicon: str = _default_favicon, + style: str = _default_style, + ) -> None: + """Initialize the OpenAPI UI render plugin. + + Args: + path: Path to serve the OpenAPI UI at. + media_type: Media type for the handler. + favicon: Html tag for the favicon. + style: Base styling of the html body. + """ + self.paths = [path] if isinstance(path, str) else list(path) + self.media_type = media_type + self.favicon = favicon + self.style = style + + @staticmethod + def render_json(request: Request, openapi_schema: dict[str, Any]) -> bytes: + """Render the OpenAPI schema as JSON. + + Args: + request: The request that triggered the render. + openapi_schema: The OpenAPI schema as a dictionary. + + Returns: + The rendered JSON. + """ + return encode_json(openapi_schema, serializer=get_serializer(request.route_handler.resolve_type_encoders())) + + @abstractmethod + def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes: + """Render the OpenAPI UI. + + Args: + request: The request that triggered the render. + openapi_schema: The OpenAPI schema as a dictionary. + + Returns: + The rendered HTML. + """ + raise NotImplementedError + + @staticmethod + def get_openapi_json_route(request: Request) -> str: + """Get the route for the OpenAPI JSON schema. + + Returns: + The route for the OpenAPI JSON schema. + """ + return request.app.route_reverse(OPENAPI_JSON_HANDLER_NAME) + + def receive_router(self, router: Router) -> None: + """Receive the router that serves the OpenAPI UI. + + Can be used by plugins to additionally configure the router, e.g. to add + additional routes. + + Args: + router: The router that serves the OpenAPI UI. + """ + return + + def has_path(self, path: str) -> bool: + """Check if the plugin has a path. + + Args: + path: The path to check. + + Returns: + True if the plugin has the path, False otherwise. + """ + return path in self.paths + + +class JsonRenderPlugin(OpenAPIRenderPlugin): + """Render the OpenAPI schema as JSON.""" + + def __init__( + self, + *, + path: str | Sequence[str] = "/openapi.json", + media_type: MediaType | OpenAPIMediaType = OpenAPIMediaType.OPENAPI_JSON, + **kwargs: Any, + ) -> None: + """Initialize the OpenAPI UI render plugin. + + Args: + path: Path to serve the OpenAPI UI at. + media_type: Media type for the handler. + **kwargs: Additional arguments to pass to the base class. + + """ + super().__init__(path=path, media_type=media_type, **kwargs) + + def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes: + """Render an OpenAPI schema as JSON. + + Args: + request: The request. + openapi_schema: The OpenAPI schema as a dictionary. + + Returns: + The rendered OpenAPI schema as JSON. + """ + return self.render_json(request, openapi_schema) + + +class YamlRenderPlugin(OpenAPIRenderPlugin): + """Render an OpenAPI schema as YAML.""" + + def __init__( + self, + *, + path: str | Sequence[str] = ("/openapi.yaml", "/openapi.yml"), + media_type: MediaType | OpenAPIMediaType = OpenAPIMediaType.OPENAPI_YAML, + **kwargs: Any, + ) -> None: + """Initialize the OpenAPI UI render plugin. + + Args: + path: Path to serve the OpenAPI UI at. + media_type: Media type for the handler. + **kwargs: Additional arguments to pass to the base class. + """ + super().__init__(path=path, media_type=media_type, **kwargs) + + def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes: + """Render an OpenAPI schema as YAML. + + Args: + request: The request. + openapi_schema: The OpenAPI schema as a dictionary. + + Returns: + The rendered OpenAPI schema as YAML. + """ + # using msgspec.to_builtins() ensures that any examples generated by polyfactory that have the + # UNSET value (possible if the examples are being generated for a partial DTO model which makes + # every type a union with UNSET) are stripped out. + openapi_schema = msgspec.to_builtins( + openapi_schema, enc_hook=get_serializer(request.route_handler.resolve_type_encoders()) + ) + return yaml.dump(openapi_schema, default_flow_style=False).encode("utf-8") + + +class RapidocRenderPlugin(OpenAPIRenderPlugin): + """Render an OpenAPI schema using Rapidoc.""" + + def __init__( + self, + *, + version: str = "9.3.4", + js_url: str | None = None, + path: str | Sequence[str] = "/rapidoc", + **kwargs: Any, + ) -> None: + """Initialize the OpenAPI UI render plugin. + + Args: + version: Rapidoc version to download from the CDN. If js_url is provided, this is ignored. + js_url: Download url for the RapiDoc JS bundle. If not provided, the version will be used to construct the + url. + path: Path to serve the OpenAPI UI at. + **kwargs: Additional arguments to pass to the base class. + """ + self.js_url = js_url or f"https://unpkg.com/rapidoc@{version}/dist/rapidoc-min.js" + super().__init__(path=path, **kwargs) + + def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes: + """Render an HTML page for Rapidoc. + + .. note:: Override this method to customize the template. + + Args: + request: The request. + openapi_schema: The OpenAPI schema as a dictionary. + + Returns: + A rendered html string. + """ + + head = f""" + + {openapi_schema["info"]["title"]} + {self.favicon} + + + + {self.style} + + """ + + body = f""" + + + + """ + + return f""" + + + {head} + {body} + + """.encode() + + +class RedocRenderPlugin(OpenAPIRenderPlugin): + """Render an OpenAPI schema using Redoc.""" + + def __init__( + self, + *, + version: str = "next", + js_url: str | None = None, + google_fonts: bool = True, + path: str | Sequence[str] = "/redoc", + **kwargs: Any, + ) -> None: + """Initialize the OpenAPI UI render plugin. + + Args: + version: Redoc version to download from the CDN. If js_url is provided, this is ignored. + js_url: Download url for the Redoc JS bundle. If not provided, the version will be used to construct the url. + google_fonts: Download google fonts via CDN. Should be set to False when not using a CDN. + path: Path to serve the OpenAPI UI at. + **kwargs: Additional arguments to pass to the base class. + """ + self.js_url = js_url or f"https://cdn.jsdelivr.net/npm/redoc@{version}/bundles/redoc.standalone.js" + self.google_fonts = google_fonts + super().__init__(path=path, **kwargs) + + def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes: + """Render an HTML page for Redoc. + + .. note:: override this method to customize the template. + + Args: + request: The request. + openapi_schema: The OpenAPI schema as a dictionary. + + Returns: + A rendered html string. + """ + + head = f""" + + {openapi_schema["info"]["title"]} + {self.favicon} + + + """ + + if self.google_fonts: + head += """ + + """ + + head += f""" + + {self.style} + + """ + + body = b"".join( + [ + b"
", + ] + ) + + return b"".join( + [ + b"", + head.encode(), + body, + b"", + ] + ) + + +class ScalarRenderPlugin(OpenAPIRenderPlugin): + """Plugin to render an OpenAPI schema using Scalar. + + .. versionadded:: 2.8.0 + """ + + _default_css_url = "https://cdn.jsdelivr.net/gh/litestar-org/branding@main/assets/openapi/scalar.css" + + def __init__( + self, + *, + version: str = "1.19.5", + js_url: str | None = None, + css_url: str | None = None, + path: str | Sequence[str] = "/scalar", + **kwargs: Any, + ) -> None: + """Initialize the Scalar OpenAPI UI render plugin. + + Args: + version: Scalar version to download from the CDN. + If js_url is provided, this is ignored. + js_url: Download url for the Scalar JS bundle. + If not provided, the version will be used to construct the url. + css_url: Download url for the Scalar CSS bundle. + If not provided, the Litestar-provided CSS will be used. + path: Path to serve the OpenAPI UI at. + **kwargs: Additional arguments to pass to the base class. + """ + self.js_url = js_url or f"https://cdn.jsdelivr.net/npm/@scalar/api-reference@{version}" + self.css_url = css_url or self._default_css_url + super().__init__(path=path, **kwargs) + + def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes: + """Render an HTMl page for Scalar. + + .. note:: Override this method to customize the template. + + Args: + request: The request. + openapi_schema: The OpenAPI schema as a dictionary. + + Returns: + A rendered html string. + """ + head = f""" + + {openapi_schema["info"]["title"]} + {self.style} + + + {self.favicon} + + + """ + + body = f""" + + + + """ + + return f""" + + + {head} + {body} + + """.encode() + + +class StoplightRenderPlugin(OpenAPIRenderPlugin): + """Render an OpenAPI schema using StopLight Elements.""" + + def __init__( + self, + *, + version: str = "7.7.18", + js_url: str | None = None, + css_url: str | None = None, + path: str | Sequence[str] = "/elements", + **kwargs: Any, + ) -> None: + """Initialize the OpenAPI UI render plugin. + + Args: + version: StopLight Elements version to download from the CDN. If js_url is provided, this is ignored. + js_url: Download url for the StopLight Elements JS bundle. If not provided, the version will be used to + construct the url. + css_url: Download url for the StopLight Elements CSS bundle. If not provided, the version will be used to + construct the url. + path: Path to serve the OpenAPI UI at. + **kwargs: Additional arguments to pass to the base class. + """ + self.js_url = js_url or f"https://unpkg.com/@stoplight/elements@{version}/web-components.min.js" + self.css_url = css_url or f"https://unpkg.com/@stoplight/elements@{version}/styles.min.css" + super().__init__(path=path, **kwargs) + + def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes: + """Render an HTML page for StopLight Elements. + + .. note:: Override this method to customize the template. + + Args: + request: The request. + openapi_schema: The OpenAPI schema as a dictionary. + + Returns: + A rendered html string. + """ + head = f""" + + {openapi_schema["info"]["title"]} + {self.favicon} + + + + + {self.style} + + """ + + body = f""" + + + + """ + + return f""" + + + {head} + {body} + + """.encode() + + +class SwaggerRenderPlugin(OpenAPIRenderPlugin): + """Render an OpenAPI schema using Swagger-UI.""" + + def __init__( + self, + version: str = "5.1.3", + js_url: str | None = None, + css_url: str | None = None, + standalone_preset_js_url: str | None = None, + init_oauth: dict[str, Any] | bytes | None = None, + path: str | Sequence[str] = "/swagger", + **kwargs: Any, + ) -> None: + """Initialize the OpenAPI UI render plugin. + + Args: + version: SwaggerUI version to download from the CDN. If js_url is provided, this is ignored. + js_url: Download url for the Swagger UI JS bundle. If not provided, the version will be used to construct + the url. + css_url: Download url for the Swagger UI CSS bundle. If not provided, the version will be used to construct + the url. + standalone_preset_js_url: Download url for the Swagger Standalone Preset JS bundle. If not provided, the + version will be used to construct the url. + init_oauth: JSON to initialize Swagger UI OAuth2 by calling the ``initOAuth`` method. + Refer to the following URL for details: + `Swagger-UI `_. + path: Path to serve the OpenAPI UI at. + **kwargs: Additional arguments to pass to the base class. + """ + self.js_url = js_url or f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{version}/swagger-ui-bundle.js" + self.css_url = css_url or f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{version}/swagger-ui.css" + self.standalone_preset_js_url = ( + standalone_preset_js_url + or f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{version}/swagger-ui-standalone-preset.js" + ) + self.init_oauth = init_oauth or {} + super().__init__(path=path, **kwargs) + + def render(self, request: Request, openapi_schema: dict[str, Any]) -> bytes: + """Render an HTML page for Swagger-UI. + + Notes: + - override this method to customize the template. + + Args: + request: The request. + openapi_schema: The OpenAPI schema as a dictionary. + + Returns: + A rendered html string. + """ + + head = f""" + + {openapi_schema["info"]["title"]} + {self.favicon} + + + + + + {self.style} + + """ + + body = b"".join( + [ + b""" + +
+ + + """, + ] + ) + + return b"".join([b"", head.encode(), body, b""]) + + def receive_router(self, router: Router) -> None: + """Receive the router that serves the OpenAPI UI. + + Adds a route to serve the OAuth2 redirect page. + + Args: + router: The router that serves the OpenAPI UI. + """ + router.register( + get("/oauth2-redirect.html", media_type=MediaType.HTML, sync_to_thread=False)(self.render_oauth2_redirect), + ) + + @staticmethod + def render_oauth2_redirect() -> bytes: + """Render an HTML oauth2-redirect.html page for Swagger-UI. + + .. note:: Override this method to customize the template. + + Returns: + A rendered html string. + """ + return rb""" + + + Swagger UI: OAuth2 Redirect + + + + + """ diff --git a/test_apps/openapi_test_app/main.py b/test_apps/openapi_test_app/main.py index b4cacfe1a7..b78bd2faaa 100644 --- a/test_apps/openapi_test_app/main.py +++ b/test_apps/openapi_test_app/main.py @@ -1,16 +1,29 @@ -from typing import Dict +from __future__ import annotations + +import msgspec from litestar import Litestar, get +from litestar.openapi.config import OpenAPIConfig +from litestar.openapi.plugins import ScalarRenderPlugin from tests.unit.test_openapi.conftest import create_person_controller, create_pet_controller -@get("/") -async def greet() -> Dict[str, str]: - return {"hello": "world"} +class Model(msgspec.Struct): + hello: str = "world" + + +@get("/", sync_to_thread=False) +def greet() -> Model: + return Model(hello="world") app = Litestar( route_handlers=[greet, create_person_controller(), create_pet_controller()], + openapi_config=OpenAPIConfig( + title="whatever", + version="0.0.1", + render_plugins=[ScalarRenderPlugin()], + ), ) diff --git a/tests/conftest.py b/tests/conftest.py index e8007dea37..633d3e6efe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,7 @@ from litestar.middleware.session.base import BaseSessionBackend from litestar.middleware.session.client_side import ClientSideSessionBackend, CookieBackendConfig from litestar.middleware.session.server_side import ServerSideSessionBackend, ServerSideSessionConfig +from litestar.openapi.config import OpenAPIConfig from litestar.stores.base import Store from litestar.stores.file import FileStore from litestar.stores.memory import MemoryStore @@ -315,3 +316,8 @@ async def redis_client(docker_ip: str, redis_service: None) -> AsyncGenerator[As await client.aclose() # type: ignore[attr-defined] except RuntimeError: pass + + +@pytest.fixture(autouse=True) +def _patch_openapi_config(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("litestar.app.DEFAULT_OPENAPI_CONFIG", OpenAPIConfig(title="Litestar API", version="1.0.0")) diff --git a/tests/examples/test_openapi/__init__.py b/tests/examples/test_openapi/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/examples/test_openapi.py b/tests/examples/test_openapi/test_openapi.py similarity index 86% rename from tests/examples/test_openapi.py rename to tests/examples/test_openapi/test_openapi.py index e83dd1f41d..c5f3a7e7ea 100644 --- a/tests/examples/test_openapi.py +++ b/tests/examples/test_openapi/test_openapi.py @@ -36,3 +36,11 @@ def test_schema_generation() -> None: } }, } + + +def test_customize_path() -> None: + from docs.examples.openapi.customize_path import app + + with TestClient(app=app) as client: + resp = client.get("/docs/openapi.json") + assert resp.status_code == 200 diff --git a/tests/examples/test_openapi/test_plugins.py b/tests/examples/test_openapi/test_plugins.py new file mode 100644 index 0000000000..8c9862000e --- /dev/null +++ b/tests/examples/test_openapi/test_plugins.py @@ -0,0 +1,118 @@ +import pytest + +from litestar.openapi.config import OpenAPIConfig +from litestar.testing import TestClient, create_test_client + + +def test_scalar_simple() -> None: + from docs.examples.openapi.plugins.scalar_simple import app + + with TestClient(app=app) as client: + resp = client.get("/schema/scalar") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert "Litestar Example" in resp.text + + +def test_rapidoc_simple() -> None: + from docs.examples.openapi.plugins.rapidoc_simple import app + + with TestClient(app=app) as client: + resp = client.get("/schema/rapidoc") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert "Litestar Example" in resp.text + + +def test_redoc_simple() -> None: + from docs.examples.openapi.plugins.redoc_simple import app + + with TestClient(app=app) as client: + resp = client.get("/schema/redoc") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert "Litestar Example" in resp.text + + +def test_stoplights_simple() -> None: + from docs.examples.openapi.plugins.stoplight_simple import app + + with TestClient(app=app) as client: + resp = client.get("/schema/elements") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert "Litestar Example" in resp.text + + +def test_swagger_ui_simple() -> None: + from docs.examples.openapi.plugins.swagger_ui_simple import app + + with TestClient(app=app) as client: + resp = client.get("/schema/swagger") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert "Litestar Example" in resp.text + + +@pytest.mark.parametrize("path", ["/schema/openapi.yml", "/schema/openapi.yaml"]) +def test_yaml_simple(path: str) -> None: + from docs.examples.openapi.plugins.yaml_simple import app + + with TestClient(app=app) as client: + resp = client.get(path) + assert resp.status_code == 200 + assert resp.headers["content-type"] == "application/vnd.oai.openapi" + assert "Litestar Example" in resp.text + + +def test_serving_multiple_uis() -> None: + from docs.examples.openapi.plugins.serving_multiple_uis import app + + with TestClient(app=app) as client: + resp = client.get("/schema/rapidoc") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert "Litestar Example" in resp.text + + resp = client.get("/schema/swagger") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert "Litestar Example" in resp.text + + +def test_custom_plugin() -> None: + from docs.examples.openapi.plugins.custom_plugin import ScalarRenderPlugin + + openapi_config = OpenAPIConfig( + title="My API", + description="This is the description of my API", + version="0.1.0", + render_plugins=[ScalarRenderPlugin()], + ) + + with create_test_client(route_handlers=[], openapi_config=openapi_config) as client: + resp = client.get("/schema/scalar") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert "My API" in resp.text + + +def test_receive_router() -> None: + from docs.examples.openapi.plugins.receive_router import MyOpenAPIPlugin + + openapi_config = OpenAPIConfig( + title="My API", + description="This is the description of my API", + version="0.1.0", + render_plugins=[MyOpenAPIPlugin(path="/custom")], + ) + + with create_test_client(route_handlers=[], openapi_config=openapi_config) as client: + resp = client.get("/schema/custom") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + assert "My UI of Choice" in resp.text + resp = client.get("/schema/something") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "text/plain; charset=utf-8" + assert "Something" in resp.text diff --git a/tests/unit/test_cli/__init__.py b/tests/unit/test_cli/__init__.py index e1cf691155..7a64e3e031 100644 --- a/tests/unit/test_cli/__init__.py +++ b/tests/unit/test_cli/__init__.py @@ -78,12 +78,9 @@ def create_app() -> Litestar: """ APP_FILE_CONTENT_ROUTES_EXAMPLE = """ from litestar import Litestar, get -from litestar.openapi import OpenAPIConfig, OpenAPIController +from litestar.openapi import OpenAPIConfig from typing import Dict -class CustomOpenAPIController(OpenAPIController): - path = "/api-docs" - @get("/") def hello_world() -> Dict[str, str]: @@ -105,7 +102,8 @@ def long_api() -> Dict[str, str]: openapi_config=OpenAPIConfig( title="test_app", version="0", - openapi_controller=CustomOpenAPIController), + path="/api-docs", + ), route_handlers=[hello_world, foo, long_api] ) diff --git a/tests/unit/test_cli/test_core_commands.py b/tests/unit/test_cli/test_core_commands.py index 4292e25b41..82c7727d93 100644 --- a/tests/unit/test_cli/test_core_commands.py +++ b/tests/unit/test_cli/test_core_commands.py @@ -404,21 +404,21 @@ def test_run_command_with_server_lifespan_plugin( @pytest.mark.parametrize( "app_content, schema_enabled, exclude_pattern_list, expected_result_routes_count", [ - pytest.param(APP_FILE_CONTENT_ROUTES_EXAMPLE, False, (), 3, id="schema-enabled_no-exclude"), + pytest.param(APP_FILE_CONTENT_ROUTES_EXAMPLE, False, (), 3, id="schema-disabled_no-exclude"), pytest.param( APP_FILE_CONTENT_ROUTES_EXAMPLE, False, ("/foo", "/destroy/.*", "/java", "/haskell"), 2, - id="schema-enabled_exclude", + id="schema-disabled_exclude", ), - pytest.param(APP_FILE_CONTENT_ROUTES_EXAMPLE, True, (), 12, id="schema-disabled_no-exclude"), + pytest.param(APP_FILE_CONTENT_ROUTES_EXAMPLE, True, (), 13, id="schema-enabled_no-exclude"), pytest.param( APP_FILE_CONTENT_ROUTES_EXAMPLE, True, ("/foo", "/destroy/.*", "/java", "/haskell"), - 11, - id="schema-disabled_exclude", + 12, + id="schema-enabled_exclude", ), ], ) diff --git a/tests/unit/test_deprecations.py b/tests/unit/test_deprecations.py index ca3bb54779..acc1beaa75 100644 --- a/tests/unit/test_deprecations.py +++ b/tests/unit/test_deprecations.py @@ -146,3 +146,25 @@ def test_is_sync_or_async_generator_deprecation() -> None: with pytest.warns(DeprecationWarning): from litestar.utils import is_sync_or_async_generator as _ # noqa: F401 + + +def test_openapi_config_openapi_controller_deprecation() -> None: + from litestar.openapi.config import OpenAPIConfig + from litestar.openapi.controller import OpenAPIController + + with pytest.warns(DeprecationWarning): + OpenAPIConfig(title="API", version="1.0", openapi_controller=OpenAPIController) + + +def test_openapi_config_root_schema_site_deprecation() -> None: + from litestar.openapi.config import OpenAPIConfig + + with pytest.warns(DeprecationWarning): + OpenAPIConfig(title="API", version="1.0", root_schema_site="redoc") + + +def test_openapi_config_enabled_endpoints_deprecation() -> None: + from litestar.openapi.config import OpenAPIConfig + + with pytest.warns(DeprecationWarning): + OpenAPIConfig(title="API", version="1.0", enabled_endpoints={"redoc"}) diff --git a/tests/unit/test_openapi/conftest.py b/tests/unit/test_openapi/conftest.py index 966dd0bcb0..20dfeb6c7a 100644 --- a/tests/unit/test_openapi/conftest.py +++ b/tests/unit/test_openapi/conftest.py @@ -6,6 +6,7 @@ from litestar import Controller, MediaType, delete, get, patch, post, put from litestar.datastructures import ResponseHeader, State from litestar.dto import DataclassDTO, DTOConfig, DTOData +from litestar.openapi.controller import OpenAPIController from litestar.openapi.spec.example import Example from litestar.params import Parameter from tests.models import DataclassPerson, DataclassPersonFactory, DataclassPet @@ -20,7 +21,7 @@ def create_person_controller() -> Type[Controller]: class PersonController(Controller): path = "/{service_id:int}/person" - @get() + @get(sync_to_thread=False) def get_persons( self, # expected to be ignored @@ -53,41 +54,46 @@ def get_persons( ) -> List[DataclassPerson]: return [] - @post(media_type=MediaType.TEXT) + @post(media_type=MediaType.TEXT, sync_to_thread=False) def create_person( self, data: DataclassPerson, secret_header: str = Parameter(header="secret") ) -> DataclassPerson: return data - @post(path="/bulk", dto=PartialDataclassPersonDTO) + @post(path="/bulk", dto=PartialDataclassPersonDTO, sync_to_thread=False) def bulk_create_person( self, data: DTOData[List[DataclassPerson]], secret_header: str = Parameter(header="secret") ) -> List[DataclassPerson]: return [] - @put(path="/bulk") + @put(path="/bulk", sync_to_thread=False) def bulk_update_person( self, data: List[DataclassPerson], secret_header: str = Parameter(header="secret") ) -> List[DataclassPerson]: return [] - @patch(path="/bulk", dto=PartialDataclassPersonDTO) + @patch(path="/bulk", dto=PartialDataclassPersonDTO, sync_to_thread=False) def bulk_partial_update_person( self, data: DTOData[List[DataclassPerson]], secret_header: str = Parameter(header="secret") ) -> List[DataclassPerson]: return [] - @get(path="/{person_id:str}") + @get(path="/{person_id:str}", sync_to_thread=False) def get_person_by_id(self, person_id: str) -> DataclassPerson: """Description in docstring.""" return DataclassPersonFactory.build(id=person_id) - @patch(path="/{person_id:str}", description="Description in decorator", dto=PartialDataclassPersonDTO) + @patch( + path="/{person_id:str}", + description="Description in decorator", + dto=PartialDataclassPersonDTO, + sync_to_thread=False, + ) def partial_update_person(self, person_id: str, data: DTOData[DataclassPerson]) -> DataclassPerson: """Description in docstring.""" return DataclassPersonFactory.build(id=person_id) - @put(path="/{person_id:str}") + @put(path="/{person_id:str}", sync_to_thread=False) def update_person(self, person_id: str, data: DataclassPerson) -> DataclassPerson: """Multiline docstring example. @@ -95,11 +101,11 @@ def update_person(self, person_id: str, data: DataclassPerson) -> DataclassPerso """ return data - @delete(path="/{person_id:str}") + @delete(path="/{person_id:str}", sync_to_thread=False) def delete_person(self, person_id: str) -> None: return None - @get(path="/dataclass") + @get(path="/dataclass", sync_to_thread=False) def get_person_dataclass(self) -> DataclassPerson: return DataclassPerson( first_name="Moishe", last_name="zuchmir", id="1", optional=None, complex={}, pets=None @@ -112,12 +118,15 @@ def create_pet_controller() -> Type[Controller]: class PetController(Controller): path = "/pet" - @get() + @get(sync_to_thread=False) def pets(self) -> List[DataclassPet]: return [] @get( - path="/owner-or-pet", response_headers=[ResponseHeader(name="x-my-tag", value="123")], raises=[PetException] + path="/owner-or-pet", + response_headers=[ResponseHeader(name="x-my-tag", value="123")], + raises=[PetException], + sync_to_thread=False, ) def get_pets_or_owners(self) -> List[Union[DataclassPerson, DataclassPet]]: return [] @@ -135,3 +144,8 @@ def person_controller() -> Type[Controller]: @pytest.mark.usefixtures("disable_warn_implicit_sync_to_thread") def pet_controller() -> Type[Controller]: return create_pet_controller() + + +@pytest.fixture(params=[OpenAPIController, None]) +def openapi_controller(request: pytest.FixtureRequest) -> Optional[Type[OpenAPIController]]: + return request.param # type: ignore[no-any-return] diff --git a/tests/unit/test_openapi/test_config.py b/tests/unit/test_openapi/test_config.py index 15dfff5dd4..712edf3e59 100644 --- a/tests/unit/test_openapi/test_config.py +++ b/tests/unit/test_openapi/test_config.py @@ -1,14 +1,17 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, List, Type import pytest from litestar import Litestar, get from litestar.exceptions import ImproperlyConfiguredException from litestar.openapi.config import OpenAPIConfig +from litestar.openapi.controller import OpenAPIController +from litestar.openapi.plugins import RedocRenderPlugin, SwaggerRenderPlugin from litestar.openapi.spec import Components, Example, OpenAPIHeader, OpenAPIType, Schema if TYPE_CHECKING: from litestar.handlers.http_handlers import HTTPRouteHandler + from litestar.openapi.plugins import OpenAPIRenderPlugin def test_merged_components_correct() -> None: @@ -83,14 +86,36 @@ def handler_2() -> None: def test_allows_customization_of_path() -> None: app = Litestar( - openapi_config=OpenAPIConfig(title="my title", version="1.0.0", path="/custom_schema_path"), + openapi_config=OpenAPIConfig( + title="my title", version="1.0.0", openapi_controller=OpenAPIController, path="/custom_schema_path" + ), ) assert app.openapi_config assert app.openapi_config.path == "/custom_schema_path" + assert app.openapi_config.openapi_controller is not None assert app.openapi_config.openapi_controller.path == "/custom_schema_path" def test_raises_exception_when_no_config_in_place() -> None: with pytest.raises(ImproperlyConfiguredException): Litestar(route_handlers=[], openapi_config=None).update_openapi_schema() + + +@pytest.mark.parametrize( + ("plugins", "exp"), + [ + ((), RedocRenderPlugin), + ([RedocRenderPlugin()], RedocRenderPlugin), + ([SwaggerRenderPlugin(), RedocRenderPlugin()], SwaggerRenderPlugin), + ([RedocRenderPlugin(), SwaggerRenderPlugin(path="/")], SwaggerRenderPlugin), + ], +) +def test_default_plugin(plugins: "List[OpenAPIRenderPlugin]", exp: "Type[OpenAPIRenderPlugin]") -> None: + config = OpenAPIConfig(title="my title", version="1.0.0", render_plugins=plugins) + assert isinstance(config.default_plugin, exp) + + +def test_default_plugin_legacy() -> None: + config = OpenAPIConfig(title="my title", version="1.0.0", openapi_controller=OpenAPIController) + assert config.default_plugin is None diff --git a/tests/unit/test_openapi/test_controller.py b/tests/unit/test_openapi/test_controller.py deleted file mode 100644 index aadf950dcd..0000000000 --- a/tests/unit/test_openapi/test_controller.py +++ /dev/null @@ -1,276 +0,0 @@ -from typing import List, Type - -import pytest - -from litestar import Controller -from litestar.enums import MediaType -from litestar.openapi.config import OpenAPIConfig -from litestar.openapi.controller import OpenAPIController -from litestar.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND -from litestar.testing import create_test_client - -root_paths: List[str] = ["", "/part1", "/part1/part2"] - - -def test_default_redoc_cdn_urls(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client([person_controller, pet_controller]) as client: - response = client.get("/schema/redoc") - default_redoc_version = "next" - default_redoc_js_bundle = ( - f"https://cdn.jsdelivr.net/npm/redoc@{default_redoc_version}/bundles/redoc.standalone.js" - ) - assert client.app.openapi_config is not None - assert client.app.openapi_config.openapi_controller.redoc_js_url == default_redoc_js_bundle - assert default_redoc_js_bundle in response.text - - -def test_default_swagger_ui_cdn_urls(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client([person_controller, pet_controller]) as client: - response = client.get("/schema/swagger") - default_swagger_bundles = [ - f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{OpenAPIController.swagger_ui_version}/swagger-ui.css", - f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{OpenAPIController.swagger_ui_version}/swagger-ui-bundle.js", - f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{OpenAPIController.swagger_ui_version}/swagger-ui-standalone-preset.js", - ] - assert client.app.openapi_config is not None - assert client.app.openapi_config.openapi_controller.swagger_css_url in default_swagger_bundles - assert client.app.openapi_config.openapi_controller.swagger_ui_bundle_js_url in default_swagger_bundles - assert ( - client.app.openapi_config.openapi_controller.swagger_ui_standalone_preset_js_url in default_swagger_bundles - ) - assert all(cdn_url in response.text for cdn_url in default_swagger_bundles) - - -def test_default_stoplight_elements_cdn_urls( - person_controller: Type[Controller], pet_controller: Type[Controller] -) -> None: - with create_test_client([person_controller, pet_controller]) as client: - response = client.get("/schema/elements") - default_stoplight_elements_bundles = [ - f"https://unpkg.com/@stoplight/elements@{OpenAPIController.stoplight_elements_version}/styles.min.css", - f"https://unpkg.com/@stoplight/elements@{OpenAPIController.stoplight_elements_version}/web-components.min.js", - ] - assert client.app.openapi_config is not None - assert ( - client.app.openapi_config.openapi_controller.stoplight_elements_css_url - in default_stoplight_elements_bundles - ) - assert ( - client.app.openapi_config.openapi_controller.stoplight_elements_js_url in default_stoplight_elements_bundles - ) - assert all(cdn_url in response.text for cdn_url in default_stoplight_elements_bundles) - - -def test_default_rapidoc_cdn_urls(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client([person_controller, pet_controller]) as client: - response = client.get("/schema/rapidoc") - default_rapidoc_bundles = [f"https://unpkg.com/rapidoc@{OpenAPIController.rapidoc_version}/dist/rapidoc-min.js"] - assert client.app.openapi_config is not None - assert client.app.openapi_config.openapi_controller.rapidoc_js_url in default_rapidoc_bundles - assert all(cdn_url in response.text for cdn_url in default_rapidoc_bundles) - - -def test_redoc_with_google_fonts(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client([person_controller, pet_controller]) as client: - response = client.get("/schema/redoc") - google_font_cdn = "https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" - assert client.app.openapi_config is not None - assert client.app.openapi_config.openapi_controller.redoc_google_fonts is True - assert google_font_cdn in response.text - - -def test_redoc_without_google_fonts(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - class OfflineOpenAPIController(OpenAPIController): - """test class for usage in a couple "offline" tests and for without Google fonts test.""" - - redoc_google_fonts = False - - offline_config = OpenAPIConfig(title="Litestar API", version="1.0.0", openapi_controller=OfflineOpenAPIController) - with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: - response = client.get("/schema/redoc") - assert "fonts.googleapis.com" not in response.text - - -def test_openapi_redoc_offline(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - class OfflineOpenAPIController(OpenAPIController): - """test class for usage in a couple "offline" tests and for without Google fonts test.""" - - redoc_js_url = "https://offline_location/redoc.standalone.js" - - offline_config = OpenAPIConfig(title="Litestar API", version="1.0.0", openapi_controller=OfflineOpenAPIController) - with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: - response = client.get("/schema/redoc") - assert OfflineOpenAPIController.redoc_js_url in response.text - - -def test_openapi_swagger_offline(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - class OfflineOpenAPIController(OpenAPIController): - """test class for usage in a couple "offline" tests and for without Google fonts test.""" - - swagger_css_url = "https://offline_location/swagger-ui-css" - swagger_ui_bundle_js_url = "https://offline_location/swagger-ui-bundle.js" - swagger_ui_standalone_preset_js_url = "https://offline_location/swagger-ui-standalone-preset.js" - - offline_config = OpenAPIConfig(title="Litestar API", version="1.0.0", openapi_controller=OfflineOpenAPIController) - with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: - response = client.get("/schema/swagger") - assert OfflineOpenAPIController.swagger_css_url in response.text - assert OfflineOpenAPIController.swagger_ui_bundle_js_url in response.text - assert OfflineOpenAPIController.swagger_ui_standalone_preset_js_url in response.text - - -def test_openapi_stoplight_elements_offline( - person_controller: Type[Controller], pet_controller: Type[Controller] -) -> None: - class OfflineOpenAPIController(OpenAPIController): - """test class for usage in a couple "offline" tests and for without Google fonts test.""" - - stoplight_elements_css_url = "https://offline_location/spotlight-styles.mins.css" - stoplight_elements_js_url = "https://offline_location/spotlight-web-components.min.js" - - offline_config = OpenAPIConfig(title="Litestar API", version="1.0.0", openapi_controller=OfflineOpenAPIController) - with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: - response = client.get("/schema/elements") - assert OfflineOpenAPIController.stoplight_elements_css_url in response.text - assert OfflineOpenAPIController.stoplight_elements_js_url in response.text - - -def test_openapi_rapidoc_offline(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - class OfflineOpenAPIController(OpenAPIController): - """test class for usage in a couple "offline" tests and for without Google fonts test.""" - - rapidoc_js_url = "https://offline_location/rapidoc-min.js" - - offline_config = OpenAPIConfig(title="Litestar API", version="1.0.0", openapi_controller=OfflineOpenAPIController) - with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: - response = client.get("/schema/rapidoc") - assert OfflineOpenAPIController.rapidoc_js_url in response.text - - -@pytest.mark.parametrize("root_path", root_paths) -def test_openapi_root(root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client([person_controller, pet_controller], root_path=root_path) as client: - response = client.get("/schema") - assert response.status_code == HTTP_200_OK - assert response.headers["content-type"].startswith(MediaType.HTML.value) - - -@pytest.mark.parametrize("root_path", root_paths) -def test_openapi_redoc(root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client([person_controller, pet_controller], root_path=root_path) as client: - response = client.get("/schema/redoc") - assert response.status_code == HTTP_200_OK - assert response.headers["content-type"].startswith(MediaType.HTML.value) - - -@pytest.mark.parametrize("root_path", root_paths) -def test_openapi_swagger(root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client([person_controller, pet_controller], root_path=root_path) as client: - response = client.get("/schema/swagger") - assert response.status_code == HTTP_200_OK - assert response.headers["content-type"].startswith(MediaType.HTML.value) - - -@pytest.mark.parametrize("root_path", root_paths) -def test_openapi_swagger_caching_schema( - root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller] -) -> None: - with create_test_client([person_controller, pet_controller], root_path=root_path) as client: - # Make sure that the schema is tweaked for swagger as the openapi version is changed. - # Because schema can get cached, make sure that getting a different schema type before works. - client.get("/schema/redoc") # Cache the schema - response = client.get("/schema/swagger") # Request swagger, should use a different cache - - assert "3.1.0" in response.text # Make sure the injected version is still there - assert response.status_code == HTTP_200_OK - assert response.headers["content-type"].startswith(MediaType.HTML.value) - - -@pytest.mark.parametrize("root_path", root_paths) -def test_openapi_stoplight_elements( - root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller] -) -> None: - with create_test_client([person_controller, pet_controller], root_path=root_path) as client: - response = client.get("/schema/elements/") - assert response.status_code == HTTP_200_OK - assert response.headers["content-type"].startswith(MediaType.HTML.value) - - -@pytest.mark.parametrize("root_path", root_paths) -def test_openapi_rapidoc(root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client([person_controller, pet_controller], root_path=root_path) as client: - response = client.get("/schema/rapidoc") - assert response.status_code == HTTP_200_OK - assert response.headers["content-type"].startswith(MediaType.HTML.value) - - -def test_openapi_root_not_allowed(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client( - [person_controller, pet_controller], - openapi_config=OpenAPIConfig( - title="Litestar API", - version="1.0.0", - enabled_endpoints={"swagger", "elements", "openapi.json", "openapi.yaml", "openapi.yml"}, - ), - ) as client: - response = client.get("/schema") - assert response.status_code == HTTP_404_NOT_FOUND - assert response.headers["content-type"].startswith(MediaType.HTML.value) - - -def test_openapi_redoc_not_allowed(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client( - [person_controller, pet_controller], - openapi_config=OpenAPIConfig( - title="Litestar API", - version="1.0.0", - enabled_endpoints={"swagger", "elements", "openapi.json", "openapi.yaml", "openapi.yml"}, - ), - ) as client: - response = client.get("/schema/redoc") - assert response.status_code == HTTP_404_NOT_FOUND - assert response.headers["content-type"].startswith(MediaType.HTML.value) - - -def test_openapi_swagger_not_allowed(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client( - [person_controller, pet_controller], - openapi_config=OpenAPIConfig( - title="Litestar API", - version="1.0.0", - enabled_endpoints={"redoc", "elements", "openapi.json", "openapi.yaml", "openapi.yml"}, - ), - ) as client: - response = client.get("/schema/swagger") - assert response.status_code == HTTP_404_NOT_FOUND - assert response.headers["content-type"].startswith(MediaType.HTML.value) - - -def test_openapi_stoplight_elements_not_allowed( - person_controller: Type[Controller], pet_controller: Type[Controller] -) -> None: - with create_test_client( - [person_controller, pet_controller], - openapi_config=OpenAPIConfig( - title="Litestar API", - version="1.0.0", - enabled_endpoints={"redoc", "swagger", "openapi.json", "openapi.yaml", "openapi.yml"}, - ), - ) as client: - response = client.get("/schema/elements/") - assert response.status_code == HTTP_404_NOT_FOUND - assert response.headers["content-type"].startswith(MediaType.HTML.value) - - -def test_openapi_rapidoc_not_allowed(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None: - with create_test_client( - [person_controller, pet_controller], - openapi_config=OpenAPIConfig( - title="Litestar API", - version="1.0.0", - enabled_endpoints={"swagger", "elements", "openapi.json", "openapi.yaml", "openapi.yml"}, - ), - ) as client: - response = client.get("/schema/rapidoc") - assert response.status_code == HTTP_404_NOT_FOUND - assert response.headers["content-type"].startswith(MediaType.HTML.value) diff --git a/tests/unit/test_openapi/test_endpoints.py b/tests/unit/test_openapi/test_endpoints.py new file mode 100644 index 0000000000..7ad694fb70 --- /dev/null +++ b/tests/unit/test_openapi/test_endpoints.py @@ -0,0 +1,487 @@ +from typing import List, Optional, Type + +import pytest + +from litestar import Controller +from litestar.enums import MediaType +from litestar.openapi.config import OpenAPIConfig +from litestar.openapi.controller import OpenAPIController +from litestar.openapi.plugins import ( + JsonRenderPlugin, + OpenAPIRenderPlugin, + RapidocRenderPlugin, + RedocRenderPlugin, + ScalarRenderPlugin, + StoplightRenderPlugin, + SwaggerRenderPlugin, +) +from litestar.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND +from litestar.testing import create_test_client + +root_paths: List[str] = ["", "/part1", "/part1/part2"] + + +@pytest.fixture() +def config(openapi_controller: Optional[Type[OpenAPIController]]) -> OpenAPIConfig: + return OpenAPIConfig(title="Litestar API", version="1.0.0", openapi_controller=openapi_controller) + + +def test_default_redoc_cdn_urls( + person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + default_redoc_version = "next" + default_redoc_js_bundle = f"https://cdn.jsdelivr.net/npm/redoc@{default_redoc_version}/bundles/redoc.standalone.js" + with create_test_client([person_controller, pet_controller], openapi_config=config) as client: + response = client.get("/schema/redoc") + assert default_redoc_js_bundle in response.text + + +def test_default_swagger_ui_cdn_urls( + person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + default_swagger_ui_version = "5.1.3" + default_swagger_bundles = [ + f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui.css", + f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui-bundle.js", + f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui-standalone-preset.js", + ] + with create_test_client([person_controller, pet_controller], openapi_config=config) as client: + response = client.get("/schema/swagger") + assert all(cdn_url in response.text for cdn_url in default_swagger_bundles) + + +def test_default_stoplight_elements_cdn_urls( + person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + default_stoplight_elements_version = "7.7.18" + default_stoplight_elements_bundles = [ + f"https://unpkg.com/@stoplight/elements@{default_stoplight_elements_version}/styles.min.css", + f"https://unpkg.com/@stoplight/elements@{default_stoplight_elements_version}/web-components.min.js", + ] + with create_test_client([person_controller, pet_controller], openapi_config=config) as client: + response = client.get("/schema/elements") + assert all(cdn_url in response.text for cdn_url in default_stoplight_elements_bundles) + + +def test_default_rapidoc_elements_cdn_urls( + person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + default_rapidoc_version = "9.3.4" + default_rapidoc_bundles = [f"https://unpkg.com/rapidoc@{default_rapidoc_version}/dist/rapidoc-min.js"] + with create_test_client([person_controller, pet_controller], openapi_config=config) as client: + response = client.get("/schema/rapidoc") + assert all(cdn_url in response.text for cdn_url in default_rapidoc_bundles) + + +def test_redoc_with_google_fonts( + person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + google_font_cdn = "https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" + with create_test_client([person_controller, pet_controller], openapi_config=config) as client: + response = client.get("/schema/redoc") + assert google_font_cdn in response.text + + +@pytest.mark.parametrize( + ("openapi_controller", "render_plugins"), + [ + (type("OfflineOpenAPIController", (OpenAPIController,), {"redoc_google_fonts": False}), []), + (None, [RedocRenderPlugin(google_fonts=False)]), + ], +) +def test_redoc_without_google_fonts( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], + render_plugins: List[OpenAPIRenderPlugin], +) -> None: + offline_config = OpenAPIConfig( + title="Litestar API", version="1.0.0", openapi_controller=openapi_controller, render_plugins=render_plugins + ) + with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: + response = client.get("/schema/redoc") + assert "fonts.googleapis.com" not in response.text + + +OFFLINE_LOCATION_JS_URL = "https://offline_location/bundle.js" +OFFLINE_LOCATION_CSS_URL = "https://offline_location/bundle.css" +OFFLINE_LOCATION_OTHER_URL = "https://offline_location/bundle.other" + + +@pytest.mark.parametrize( + ("openapi_controller", "render_plugins"), + [ + (type("OfflineOpenAPIController", (OpenAPIController,), {"redoc_js_url": OFFLINE_LOCATION_JS_URL}), []), + (None, [RedocRenderPlugin(js_url=OFFLINE_LOCATION_JS_URL)]), + ], +) +def test_openapi_redoc_offline( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], + render_plugins: List[OpenAPIRenderPlugin], +) -> None: + offline_config = OpenAPIConfig( + title="Litestar API", version="1.0.0", openapi_controller=openapi_controller, render_plugins=render_plugins + ) + with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: + response = client.get("/schema/redoc") + assert OFFLINE_LOCATION_JS_URL in response.text + + +@pytest.mark.parametrize( + ("openapi_controller", "render_plugins"), + [ + ( + type( + "OfflineOpenAPIController", + (OpenAPIController,), + { + "swagger_ui_bundle_js_url": OFFLINE_LOCATION_JS_URL, + "swagger_css_url": OFFLINE_LOCATION_CSS_URL, + "swagger_ui_standalone_preset_js_url": OFFLINE_LOCATION_OTHER_URL, + }, + ), + [], + ), + ( + None, + [ + SwaggerRenderPlugin( + js_url=OFFLINE_LOCATION_JS_URL, + css_url=OFFLINE_LOCATION_CSS_URL, + standalone_preset_js_url=OFFLINE_LOCATION_OTHER_URL, + ) + ], + ), + ], +) +def test_openapi_swagger_offline( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], + render_plugins: List[OpenAPIRenderPlugin], +) -> None: + offline_config = OpenAPIConfig( + title="Litestar API", version="1.0.0", openapi_controller=openapi_controller, render_plugins=render_plugins + ) + with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: + response = client.get("/schema/swagger") + assert all( + offline_url in response.text + for offline_url in [OFFLINE_LOCATION_JS_URL, OFFLINE_LOCATION_CSS_URL, OFFLINE_LOCATION_OTHER_URL] + ) + + +@pytest.mark.parametrize( + ("openapi_controller", "render_plugins"), + [ + ( + type( + "OfflineOpenAPIController", + (OpenAPIController,), + { + "stoplight_elements_css_url": OFFLINE_LOCATION_CSS_URL, + "stoplight_elements_js_url": OFFLINE_LOCATION_JS_URL, + }, + ), + [], + ), + ( + None, + [ + StoplightRenderPlugin( + js_url=OFFLINE_LOCATION_JS_URL, + css_url=OFFLINE_LOCATION_CSS_URL, + ) + ], + ), + ], +) +def test_openapi_stoplight_elements_offline( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], + render_plugins: List[OpenAPIRenderPlugin], +) -> None: + offline_config = OpenAPIConfig( + title="Litestar API", version="1.0.0", openapi_controller=openapi_controller, render_plugins=render_plugins + ) + with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: + response = client.get("/schema/elements") + assert all(offline_url in response.text for offline_url in [OFFLINE_LOCATION_JS_URL, OFFLINE_LOCATION_CSS_URL]) + + +@pytest.mark.parametrize( + ("openapi_controller", "render_plugins"), + [ + ( + None, + [ + ScalarRenderPlugin( + js_url=OFFLINE_LOCATION_JS_URL, + css_url=OFFLINE_LOCATION_CSS_URL, + ) + ], + ), + ], +) +def test_openapi_scalar_offline( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], + render_plugins: List[OpenAPIRenderPlugin], +) -> None: + offline_config = OpenAPIConfig( + title="Litestar API", version="1.0.0", openapi_controller=openapi_controller, render_plugins=render_plugins + ) + with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: + response = client.get("/schema/scalar") + assert all(offline_url in response.text for offline_url in [OFFLINE_LOCATION_JS_URL, OFFLINE_LOCATION_CSS_URL]) + + +@pytest.mark.parametrize( + ("openapi_controller", "render_plugins"), + [ + (type("OfflineOpenAPIController", (OpenAPIController,), {"rapidoc_js_url": OFFLINE_LOCATION_JS_URL}), []), + (None, [RapidocRenderPlugin(js_url=OFFLINE_LOCATION_JS_URL)]), + ], +) +def test_openapi_rapidoc_offline( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], + render_plugins: List[OpenAPIRenderPlugin], +) -> None: + offline_config = OpenAPIConfig( + title="Litestar API", version="1.0.0", openapi_controller=openapi_controller, render_plugins=render_plugins + ) + with create_test_client([person_controller, pet_controller], openapi_config=offline_config) as client: + response = client.get("/schema/rapidoc") + assert OFFLINE_LOCATION_JS_URL in response.text + + +@pytest.mark.parametrize("root_path", root_paths) +def test_openapi_root( + root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + with create_test_client([person_controller, pet_controller], root_path=root_path, openapi_config=config) as client: + response = client.get("/schema") + assert response.status_code == HTTP_200_OK + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +@pytest.mark.parametrize("root_path", root_paths) +def test_openapi_redoc( + root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + with create_test_client([person_controller, pet_controller], root_path=root_path, openapi_config=config) as client: + response = client.get("/schema/redoc") + assert response.status_code == HTTP_200_OK + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +@pytest.mark.parametrize("root_path", root_paths) +def test_openapi_swagger( + root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + with create_test_client([person_controller, pet_controller], root_path=root_path, openapi_config=config) as client: + response = client.get("/schema/swagger") + assert response.status_code == HTTP_200_OK + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +@pytest.mark.parametrize("root_path", root_paths) +def test_openapi_swagger_caching_schema( + root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + with create_test_client([person_controller, pet_controller], root_path=root_path, openapi_config=config) as client: + # Make sure that the schema is tweaked for swagger as the openapi version is changed. + # Because schema can get cached, make sure that getting a different schema type before works. + client.get("/schema/redoc") # Cache the schema + response = client.get("/schema/swagger") # Request swagger, should use a different cache + + assert "3.1.0" in response.text # Make sure the injected version is still there + assert response.status_code == HTTP_200_OK + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +@pytest.mark.parametrize("root_path", root_paths) +def test_openapi_stoplight_elements( + root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + with create_test_client([person_controller, pet_controller], root_path=root_path, openapi_config=config) as client: + response = client.get("/schema/elements/") + assert response.status_code == HTTP_200_OK + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +@pytest.mark.parametrize("root_path", root_paths) +def test_openapi_rapidoc( + root_path: str, person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig +) -> None: + with create_test_client([person_controller, pet_controller], root_path=root_path, openapi_config=config) as client: + response = client.get("/schema/rapidoc") + assert response.status_code == HTTP_200_OK + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +def test_openapi_root_not_allowed( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], +) -> None: + with create_test_client( + [person_controller, pet_controller], + openapi_config=OpenAPIConfig( + title="Litestar API", + version="1.0.0", + enabled_endpoints={"swagger", "elements", "openapi.json", "openapi.yaml", "openapi.yml"}, + openapi_controller=openapi_controller, + ), + ) as client: + response = client.get("/schema") + assert response.status_code == HTTP_404_NOT_FOUND + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +def test_openapi_redoc_not_allowed( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], +) -> None: + with create_test_client( + [person_controller, pet_controller], + openapi_config=OpenAPIConfig( + title="Litestar API", + version="1.0.0", + enabled_endpoints={"swagger", "elements", "openapi.json", "openapi.yaml", "openapi.yml"}, + openapi_controller=openapi_controller, + ), + ) as client: + response = client.get("/schema/redoc") + assert response.status_code == HTTP_404_NOT_FOUND + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +def test_openapi_swagger_not_allowed( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], +) -> None: + with create_test_client( + [person_controller, pet_controller], + openapi_config=OpenAPIConfig( + title="Litestar API", + version="1.0.0", + enabled_endpoints={"redoc", "elements", "openapi.json", "openapi.yaml", "openapi.yml"}, + openapi_controller=openapi_controller, + ), + ) as client: + response = client.get("/schema/swagger") + assert response.status_code == HTTP_404_NOT_FOUND + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +def test_openapi_stoplight_elements_not_allowed( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], +) -> None: + with create_test_client( + [person_controller, pet_controller], + openapi_config=OpenAPIConfig( + title="Litestar API", + version="1.0.0", + enabled_endpoints={"redoc", "swagger", "openapi.json", "openapi.yaml", "openapi.yml"}, + openapi_controller=openapi_controller, + ), + ) as client: + response = client.get("/schema/elements/") + assert response.status_code == HTTP_404_NOT_FOUND + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +def test_openapi_rapidoc_not_allowed( + person_controller: Type[Controller], + pet_controller: Type[Controller], + openapi_controller: Optional[Type[OpenAPIController]], +) -> None: + with create_test_client( + [person_controller, pet_controller], + openapi_config=OpenAPIConfig( + title="Litestar API", + version="1.0.0", + enabled_endpoints={"swagger", "elements", "openapi.json", "openapi.yaml", "openapi.yml"}, + openapi_controller=openapi_controller, + ), + ) as client: + response = client.get("/schema/rapidoc") + assert response.status_code == HTTP_404_NOT_FOUND + assert response.headers["content-type"].startswith(MediaType.HTML.value) + + +@pytest.mark.parametrize( + ("render_plugins",), + [ + ([],), + ([RedocRenderPlugin()],), + ([RedocRenderPlugin(), JsonRenderPlugin()],), + ([JsonRenderPlugin(path="/custom_path")],), + ([JsonRenderPlugin(path=["/openapi.json", "/custom_path"])],), + ], +) +def test_json_plugin_always_enabled(render_plugins: List["OpenAPIRenderPlugin"]) -> None: + """We assume that an '/openapi.json' path is available in many of the openapi render plugins. + + This test ensures that the json plugin is always enabled, even if the user has not explicitly + included it in the render_plugins list. + """ + + openapi_config = OpenAPIConfig(title="my title", version="1.0.0", render_plugins=render_plugins) + with create_test_client([], openapi_config=openapi_config) as client: + response = client.get("/schema/openapi.json") + assert response.status_code == HTTP_200_OK + + +def test_default_plugin_explicit_path() -> None: + config = OpenAPIConfig(title="my title", version="1.0.0", render_plugins=[SwaggerRenderPlugin(path="/")]) + with create_test_client([], openapi_config=config) as client: + response = client.get("/schema/") + assert response.status_code == HTTP_200_OK + + response = client.get("/schema/swagger") + assert response.status_code == HTTP_404_NOT_FOUND + + +def test_default_plugin_backward_compatibility() -> None: + config = OpenAPIConfig(title="my title", version="1.0.0") + with create_test_client([], openapi_config=config) as client: + response = client.get("/schema/") + assert response.status_code == HTTP_200_OK + + response = client.get("/schema/redoc") + assert response.status_code == HTTP_200_OK + + +def test_default_plugin_backward_compatibility_not_found() -> None: + config = OpenAPIConfig(title="my title", version="1.0.0", enabled_endpoints={"redoc"}, root_schema_site="swagger") + with create_test_client([], openapi_config=config) as client: + response = client.get("/schema/") + assert response.status_code == HTTP_404_NOT_FOUND + + response = client.get("/schema/swagger") + assert response.status_code == HTTP_404_NOT_FOUND + + response = client.get("/schema/redoc") + assert response.status_code == HTTP_200_OK + + +def test_default_plugin_future_compatibility() -> None: + config = OpenAPIConfig(title="my title", version="1.0.0", render_plugins=[SwaggerRenderPlugin()]) + with create_test_client([], openapi_config=config) as client: + response = client.get("/schema/") + assert response.status_code == HTTP_200_OK + + response = client.get("/schema/swagger") + assert response.status_code == HTTP_200_OK diff --git a/tests/unit/test_openapi/test_integration.py b/tests/unit/test_openapi/test_integration.py index 42c5b2e4a3..df67fdf6ac 100644 --- a/tests/unit/test_openapi/test_integration.py +++ b/tests/unit/test_openapi/test_integration.py @@ -11,7 +11,6 @@ from litestar import Controller, Litestar, delete, get, patch, post from litestar._openapi.plugin import OpenAPIPlugin -from litestar.app import DEFAULT_OPENAPI_CONFIG from litestar.enums import MediaType, OpenAPIMediaType, ParamType from litestar.openapi import OpenAPIConfig, OpenAPIController from litestar.openapi.spec import Parameter as OpenAPIParameter @@ -23,12 +22,22 @@ CREATE_EXAMPLES_VALUES = (True, False) -@pytest.mark.parametrize("create_examples", CREATE_EXAMPLES_VALUES) +@pytest.fixture(params=[True, False]) +def create_examples(request: pytest.FixtureRequest) -> bool: + return request.param # type: ignore[no-any-return] + + @pytest.mark.parametrize("schema_path", ["/schema/openapi.yaml", "/schema/openapi.yml"]) def test_openapi( - person_controller: type[Controller], pet_controller: type[Controller], create_examples: bool, schema_path: str + person_controller: type[Controller], + pet_controller: type[Controller], + create_examples: bool, + schema_path: str, + openapi_controller: type[OpenAPIController] | None, ) -> None: - openapi_config = OpenAPIConfig("Example API", "1.0.0", create_examples=create_examples) + openapi_config = OpenAPIConfig( + "Example API", "1.0.0", create_examples=create_examples, openapi_controller=openapi_controller + ) with create_test_client([person_controller, pet_controller], openapi_config=openapi_config) as client: assert client.app.openapi_schema openapi_schema = client.app.openapi_schema @@ -42,11 +51,15 @@ def test_openapi( assert response.content.decode("utf-8") == yaml.dump(schema_json) -@pytest.mark.parametrize("create_examples", CREATE_EXAMPLES_VALUES) def test_openapi_json( - person_controller: type[Controller], pet_controller: type[Controller], create_examples: bool + person_controller: type[Controller], + pet_controller: type[Controller], + create_examples: bool, + openapi_controller: type[OpenAPIController] | None, ) -> None: - openapi_config = OpenAPIConfig("Example API", "1.0.0", create_examples=create_examples) + openapi_config = OpenAPIConfig( + "Example API", "1.0.0", create_examples=create_examples, openapi_controller=openapi_controller + ) with create_test_client([person_controller, pet_controller], openapi_config=openapi_config) as client: assert client.app.openapi_schema openapi_schema = client.app.openapi_schema @@ -63,10 +76,15 @@ def test_openapi_json( "endpoint, schema_path", [("openapi.yaml", "/schema/openapi.yaml"), ("openapi.yml", "/schema/openapi.yml")] ) def test_openapi_yaml_not_allowed( - endpoint: str, schema_path: str, person_controller: type[Controller], pet_controller: type[Controller] + endpoint: str, + schema_path: str, + person_controller: type[Controller], + pet_controller: type[Controller], + openapi_controller: type[OpenAPIController] | None, ) -> None: - openapi_config = DEFAULT_OPENAPI_CONFIG - openapi_config.enabled_endpoints.discard(endpoint) + openapi_config = OpenAPIConfig( + "Example API", "1.0.0", enabled_endpoints=set(), openapi_controller=openapi_controller + ) with create_test_client([person_controller, pet_controller], openapi_config=openapi_config) as client: assert client.app.openapi_schema @@ -77,8 +95,13 @@ def test_openapi_yaml_not_allowed( def test_openapi_json_not_allowed(person_controller: type[Controller], pet_controller: type[Controller]) -> None: - openapi_config = DEFAULT_OPENAPI_CONFIG - openapi_config.enabled_endpoints.discard("openapi.json") + # only tested with the OpenAPIController, b/c new router based approach always serves `openapi.json`. + openapi_config = OpenAPIConfig( + "Example API", + "1.0.0", + enabled_endpoints=set(), + openapi_controller=OpenAPIController, + ) with create_test_client([person_controller, pet_controller], openapi_config=openapi_config) as client: assert client.app.openapi_schema @@ -88,8 +111,10 @@ def test_openapi_json_not_allowed(person_controller: type[Controller], pet_contr assert response.status_code == HTTP_404_NOT_FOUND -def test_openapi_custom_path() -> None: - openapi_config = OpenAPIConfig(title="my title", version="1.0.0", path="/custom_schema_path") +def test_openapi_custom_path(openapi_controller: type[OpenAPIController] | None) -> None: + openapi_config = OpenAPIConfig( + title="my title", version="1.0.0", path="/custom_schema_path", openapi_controller=openapi_controller + ) with create_test_client([], openapi_config=openapi_config) as client: response = client.get("/schema") assert response.status_code == HTTP_404_NOT_FOUND @@ -101,8 +126,10 @@ def test_openapi_custom_path() -> None: assert response.status_code == HTTP_200_OK -def test_openapi_normalizes_custom_path() -> None: - openapi_config = OpenAPIConfig(title="my title", version="1.0.0", path="custom_schema_path") +def test_openapi_normalizes_custom_path(openapi_controller: type[OpenAPIController] | None) -> None: + openapi_config = OpenAPIConfig( + title="my title", version="1.0.0", path="custom_schema_path", openapi_controller=openapi_controller + ) with create_test_client([], openapi_config=openapi_config) as client: response = client.get("/custom_schema_path/openapi.json") assert response.status_code == HTTP_200_OK @@ -145,8 +172,7 @@ class CustomOpenAPIController(OpenAPIController): assert response.status_code == HTTP_200_OK -@pytest.mark.parametrize("create_examples", CREATE_EXAMPLES_VALUES) -def test_msgspec_schema_generation(create_examples: bool) -> None: +def test_msgspec_schema_generation(create_examples: bool, openapi_controller: type[OpenAPIController] | None) -> None: class Lookup(msgspec.Struct): id: Annotated[ str, @@ -168,6 +194,7 @@ async def example_route() -> Lookup: title="Example API", version="1.0.0", create_examples=create_examples, + openapi_controller=openapi_controller, ), signature_types=[Lookup], ) as client: @@ -184,7 +211,7 @@ async def example_route() -> Lookup: } -def test_schema_for_optional_path_parameter() -> None: +def test_schema_for_optional_path_parameter(openapi_controller: type[OpenAPIController] | None) -> None: @get(path=["/", "/{test_message:str}"], media_type=MediaType.TEXT, sync_to_thread=False) def handler(test_message: Optional[str]) -> str: # noqa: UP007 return test_message or "no message" @@ -195,6 +222,7 @@ def handler(test_message: Optional[str]) -> str: # noqa: UP007 title="Example API", version="1.0.0", create_examples=True, + openapi_controller=openapi_controller, ), ) as client: response = client.get("/schema/openapi.json") @@ -214,7 +242,7 @@ class Foo(Generic[T]): foo: T -def test_with_generic_class() -> None: +def test_with_generic_class(openapi_controller: type[OpenAPIController] | None) -> None: @get("/foo-str", sync_to_thread=False) def handler_foo_str() -> Foo[str]: return Foo("") @@ -228,6 +256,7 @@ def handler_foo_int() -> Foo[int]: openapi_config=OpenAPIConfig( title="Example API", version="1.0.0", + openapi_controller=openapi_controller, ), ) as client: response = client.get("/schema/openapi.json") diff --git a/tests/unit/test_openapi/test_render_plugins.py b/tests/unit/test_openapi/test_render_plugins.py new file mode 100644 index 0000000000..947c57297f --- /dev/null +++ b/tests/unit/test_openapi/test_render_plugins.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from litestar.openapi.plugins import OpenAPIRenderPlugin +from litestar.testing import RequestFactory + + +def test_render_plugin_get_openapi_json_route() -> None: + request = RequestFactory().get() + assert OpenAPIRenderPlugin.get_openapi_json_route(request) == "/schema/openapi.json" From 285d3f4aed6af3252ec2f86d8720b34323ea7b7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Sun, 17 Mar 2024 10:04:20 +0100 Subject: [PATCH 12/19] feat(DTO): Enable codegen backend by default (#3215) * Enable DTO codegen backend by default * Update docs --- docs/usage/dto/0-basic-use.rst | 9 +++++---- litestar/app.py | 15 +++++++++++++-- litestar/dto/base_dto.py | 4 +--- litestar/handlers/base.py | 11 ----------- tests/unit/test_app.py | 7 ++++++- tests/unit/test_dto/test_integration.py | 13 +++++++++++++ 6 files changed, 38 insertions(+), 21 deletions(-) diff --git a/docs/usage/dto/0-basic-use.rst b/docs/usage/dto/0-basic-use.rst index e5a4df46cd..46693489e6 100644 --- a/docs/usage/dto/0-basic-use.rst +++ b/docs/usage/dto/0-basic-use.rst @@ -78,11 +78,12 @@ DTOs can similarly be defined on :class:`Routers ` and Improving performance with the codegen backend ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. admonition:: Experimental feature - :class: danger +.. note:: - This is an experimental feature and should be approached with caution. It may - behave unexpectedly, contain bugs and may disappear again in a future version. + This feature was introduced in ``2.2.0`` and hidden behind the ``DTO_CODEGEN`` + feature flag. As of ``2.8.0`` it is considered stable and enabled by default. It can + still be disabled selectively by using the + ``DTOConfig(experimental_codegen_backend=True)`` override. The DTO backend is the part that does the heavy lifting for all the DTO features. It diff --git a/litestar/app.py b/litestar/app.py index 995e90d6bd..3dbdbf676f 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -3,6 +3,7 @@ import inspect import logging import os +import warnings from contextlib import ( AbstractAsyncContextManager, AsyncExitStack, @@ -20,12 +21,13 @@ from litestar._openapi.plugin import OpenAPIPlugin from litestar._openapi.schema_generation import openapi_schema_plugins from litestar.config.allowed_hosts import AllowedHostsConfig -from litestar.config.app import AppConfig +from litestar.config.app import AppConfig, ExperimentalFeatures from litestar.config.response_cache import ResponseCacheConfig from litestar.connection import Request, WebSocket from litestar.datastructures.state import State from litestar.events.emitter import BaseEventEmitterBackend, SimpleEventEmitter from litestar.exceptions import ( + LitestarWarning, MissingDependencyException, NoRouteMatchFoundException, ) @@ -55,7 +57,6 @@ if TYPE_CHECKING: from typing_extensions import Self - from litestar.config.app import ExperimentalFeatures from litestar.config.compression import CompressionConfig from litestar.config.cors import CORSConfig from litestar.config.csrf import CSRFConfig @@ -392,6 +393,16 @@ def __init__( self._lifespan_managers.append(store) self._server_lifespan_managers = [p.server_lifespan for p in config.plugins or [] if isinstance(p, CLIPlugin)] self.experimental_features = frozenset(config.experimental_features or []) + if ExperimentalFeatures.DTO_CODEGEN in self.experimental_features: + warnings.warn( + "Use of redundant experimental feature flag DTO_CODEGEN. " + "DTO codegen backend is enabled by default since Litestar 2.8. The " + "DTO_CODEGEN feature flag can be safely removed from the configuration " + "and will be removed in version 3.0.", + category=LitestarWarning, + stacklevel=2, + ) + self.get_logger: GetLogger = get_logger_placeholder self.logger: Logger | None = None self.routes: list[HTTPRoute | ASGIRoute | WebSocketRoute] = [] diff --git a/litestar/dto/base_dto.py b/litestar/dto/base_dto.py index 991b09fa41..9124d1f35c 100644 --- a/litestar/dto/base_dto.py +++ b/litestar/dto/base_dto.py @@ -178,9 +178,7 @@ def create_for_field_definition( ) if backend_cls is None: - backend_cls = DTOCodegenBackend if cls.config.experimental_codegen_backend else DTOBackend - elif backend_cls is DTOCodegenBackend and cls.config.experimental_codegen_backend is False: - backend_cls = DTOBackend + backend_cls = DTOCodegenBackend if cls.config.experimental_codegen_backend is not False else DTOBackend backend_context[key] = backend_cls( # type: ignore[literal-required] dto_factory=cls, diff --git a/litestar/handlers/base.py b/litestar/handlers/base.py index 9dbb70e28b..1803af39a5 100644 --- a/litestar/handlers/base.py +++ b/litestar/handlers/base.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence, cast from litestar._signature import SignatureModel -from litestar.config.app import ExperimentalFeatures from litestar.di import Provide from litestar.dto import DTOData from litestar.exceptions import ImproperlyConfiguredException @@ -32,7 +31,6 @@ from litestar.connection import ASGIConnection from litestar.controller import Controller from litestar.dto import AbstractDTO - from litestar.dto._backend import DTOBackend from litestar.params import ParameterKwarg from litestar.router import Router from litestar.types import AnyCallable, AsyncAnyCallable, ExceptionHandler @@ -442,13 +440,6 @@ def resolve_signature_namespace(self) -> dict[str, Any]: self._resolved_signature_namespace = ns return cast("dict[str, Any]", self._resolved_signature_namespace) - def _get_dto_backend_cls(self) -> type[DTOBackend] | None: - if ExperimentalFeatures.DTO_CODEGEN in self.app.experimental_features: - from litestar.dto._codegen_backend import DTOCodegenBackend - - return DTOCodegenBackend - return None - def resolve_data_dto(self) -> type[AbstractDTO] | None: """Resolve the data_dto by starting from the route handler and moving up. If a handler is found it is returned, otherwise None is set. @@ -478,7 +469,6 @@ def resolve_data_dto(self) -> type[AbstractDTO] | None: data_dto.create_for_field_definition( field_definition=self.parsed_data_field, handler_id=self.handler_id, - backend_cls=self._get_dto_backend_cls(), ) self._resolved_data_dto = data_dto @@ -512,7 +502,6 @@ def resolve_return_dto(self) -> type[AbstractDTO] | None: return_dto.create_for_field_definition( field_definition=self.parsed_return_field, handler_id=self.handler_id, - backend_cls=self._get_dto_backend_cls(), ) self._resolved_return_dto = return_dto else: diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index cc7ee81826..809b8abe1e 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -14,7 +14,7 @@ from pytest import MonkeyPatch from litestar import Litestar, MediaType, Request, Response, get -from litestar.config.app import AppConfig +from litestar.config.app import AppConfig, ExperimentalFeatures from litestar.config.response_cache import ResponseCacheConfig from litestar.contrib.sqlalchemy.plugins import SQLAlchemySerializationPlugin from litestar.datastructures import MutableScopeHeaders, State @@ -445,3 +445,8 @@ async def hook_b(app: Litestar) -> None: assert events[1] == "ctx_1" assert events[2] == "hook_a" assert events[3] == "hook_b" + + +def test_use_dto_codegen_feature_flag_warns() -> None: + with pytest.warns(LitestarWarning, match="Use of redundant experimental feature flag DTO_CODEGEN"): + Litestar(experimental_features=[ExperimentalFeatures.DTO_CODEGEN]) diff --git a/tests/unit/test_dto/test_integration.py b/tests/unit/test_dto/test_integration.py index 6e90055e78..f9d7ba4d57 100644 --- a/tests/unit/test_dto/test_integration.py +++ b/tests/unit/test_dto/test_integration.py @@ -153,3 +153,16 @@ def handler(data: Model) -> Model: backend = handler.resolve_data_dto()._dto_backends[handler.handler_id]["data_backend"] # type: ignore[union-attr] assert isinstance(backend, DTOBackend) + + +def test_use_codegen_backend_by_default(ModelDataDTO: type[AbstractDTO]) -> None: + ModelDataDTO.config = DTOConfig() + + @post(dto=ModelDataDTO, signature_types=[Model]) + def handler(data: Model) -> Model: + return data + + Litestar(route_handlers=[handler]) + + backend = handler.resolve_data_dto()._dto_backends[handler.handler_id]["data_backend"] # type: ignore[union-attr] + assert isinstance(backend, DTOBackend) From 588040674f6996e9475994ee674dee71b57d6d24 Mon Sep 17 00:00:00 2001 From: kedod <35638715+kedod@users.noreply.github.com> Date: Sun, 17 Mar 2024 19:03:33 +0100 Subject: [PATCH 13/19] feat: Added precedence of CLI parameters over envs (#3190) * feat: Added precedence of CLI parameters over envs * Update docs/usage/cli.rst Co-authored-by: Peter Schutt * Remove redundant LitestarEnv fields and fix tests * Update docs/usage/cli.rst * Update litestar/cli/commands/core.py * Update docs/usage/cli.rst * Update docs/usage/cli.rst * Update litestar/cli/commands/core.py --------- Co-authored-by: kedod Co-authored-by: Peter Schutt Co-authored-by: Jacob Coffee --- docs/usage/cli.rst | 3 + litestar/cli/_utils.py | 28 ----- litestar/cli/commands/core.py | 77 +++++++++----- tests/unit/test_cli/test_core_commands.py | 114 ++++++++++++++++++++- tests/unit/test_cli/test_env_resolution.py | 24 ----- 5 files changed, 165 insertions(+), 81 deletions(-) diff --git a/docs/usage/cli.rst b/docs/usage/cli.rst index b52463d315..8333d1694f 100644 --- a/docs/usage/cli.rst +++ b/docs/usage/cli.rst @@ -95,6 +95,9 @@ The ``run`` command executes a Litestar application using `uvicorn LitestarEnv: @@ -119,31 +108,14 @@ def from_env(cls, app_path: str | None, app_dir: Path | None = None) -> Litestar loaded_app = _autodiscover_app(cwd) port = getenv("LITESTAR_PORT") - web_concurrency = getenv("WEB_CONCURRENCY") - uds = getenv("LITESTAR_UNIX_DOMAIN_SOCKET") - fd = getenv("LITESTAR_FILE_DESCRIPTOR") - reload_dirs = tuple(s.strip() for s in getenv("LITESTAR_RELOAD_DIRS", "").split(",") if s) or None - reload_include = tuple(s.strip() for s in getenv("LITESTAR_RELOAD_INCLUDES", "").split(",") if s) or None - reload_exclude = tuple(s.strip() for s in getenv("LITESTAR_RELOAD_EXCLUDES", "").split(",") if s) or None return cls( app_path=loaded_app.app_path, app=loaded_app.app, - debug=_bool_from_env("LITESTAR_DEBUG"), host=getenv("LITESTAR_HOST"), port=int(port) if port else None, - uds=uds, - fd=int(fd) if fd else None, - reload=_bool_from_env("LITESTAR_RELOAD"), - reload_dirs=reload_dirs, - reload_include=reload_include, - reload_exclude=reload_exclude, - web_concurrency=int(web_concurrency) if web_concurrency else None, is_app_factory=loaded_app.is_factory, cwd=cwd, - certfile_path=getenv("LITESTAR_SSL_CERT_PATH"), - keyfile_path=getenv("LITESTAR_SSL_KEY_PATH"), - create_self_signed_cert=_bool_from_env("LITESTAR_CREATE_SELF_SIGNED_CERT"), ) diff --git a/litestar/cli/commands/core.py b/litestar/cli/commands/core.py index 06e5d4175f..e3273f1d82 100644 --- a/litestar/cli/commands/core.py +++ b/litestar/cli/commands/core.py @@ -103,6 +103,15 @@ def _run_uvicorn_in_subprocess( ) +class CommaSplittedPath(click.Path): + """A Click Path that splits the input string by commas. + + .. versionadded:: 2.8.0 + """ + + envvar_list_splitter = "," + + @command(name="version") @option("-s", "--short", help="Exclude release level and serial information", is_flag=True, default=False) def version_command(short: bool) -> None: @@ -120,15 +129,32 @@ def info_command(app: Litestar) -> None: @command(name="run") -@option("-r", "--reload", help="Reload server on changes", default=False, is_flag=True) -@option("-R", "--reload-dir", help="Directories to watch for file changes", multiple=True) +@option("-r", "--reload", help="Reload server on changes", default=False, is_flag=True, envvar="LITESTAR_RELOAD") +@option( + "-R", + "--reload-dir", + help="Directories to watch for file changes", + type=CommaSplittedPath(), + multiple=True, + envvar="LITESTAR_RELOAD_DIRS", +) @option( - "-I", "--reload-include", help="Glob patterns for files to include when watching for file changes", multiple=True + "-I", + "--reload-include", + help="Glob patterns for files to include when watching for file changes", + type=CommaSplittedPath(), + multiple=True, + envvar="LITESTAR_RELOAD_INCLUDES", ) @option( - "-E", "--reload-exclude", help="Glob patterns for files to exclude when watching for file changes", multiple=True + "-E", + "--reload-exclude", + help="Glob patterns for files to exclude when watching for file changes", + type=CommaSplittedPath(), + multiple=True, + envvar="LITESTAR_RELOAD_EXCLUDES", ) -@option("-p", "--port", help="Serve under this port", type=int, default=8000, show_default=True) +@option("-p", "--port", help="Serve under this port", type=int, default=8000, show_default=True, envvar="LITESTAR_PORT") @option( "-W", "--wc", @@ -137,8 +163,9 @@ def info_command(app: Litestar) -> None: type=click.IntRange(min=1, max=multiprocessing.cpu_count() + 1), show_default=True, default=1, + envvar="WEB_CONCURRENCY", ) -@option("-H", "--host", help="Server under this host", default="127.0.0.1", show_default=True) +@option("-H", "--host", help="Server under this host", default="127.0.0.1", show_default=True, envvar="LITESTAR_HOST") @option( "-F", "--fd", @@ -147,16 +174,26 @@ def info_command(app: Litestar) -> None: type=int, default=None, show_default=True, + envvar="LITESTAR_FILE_DESCRIPTOR", ) -@option("-U", "--uds", "--unix-domain-socket", help="Bind to a UNIX domain socket.", default=None, show_default=True) -@option("-d", "--debug", help="Run app in debug mode", is_flag=True) -@option("-P", "--pdb", "--use-pdb", help="Drop into PDB on an exception", is_flag=True) -@option("--ssl-certfile", help="Location of the SSL cert file", default=None) -@option("--ssl-keyfile", help="Location of the SSL key file", default=None) +@option( + "-U", + "--uds", + "--unix-domain-socket", + help="Bind to a UNIX domain socket.", + default=None, + show_default=True, + envvar="LITESTAR_UNIX_DOMAIN_SOCKET", +) +@option("-d", "--debug", help="Run app in debug mode", is_flag=True, envvar="LITESTAR_DEBUG") +@option("-P", "--pdb", "--use-pdb", help="Drop into PDB on an exception", is_flag=True, envvar="LITESTAR_PDB") +@option("--ssl-certfile", help="Location of the SSL cert file", default=None, envvar="LITESTAR_SSL_CERT_PATH") +@option("--ssl-keyfile", help="Location of the SSL key file", default=None, envvar="LITESTAR_SSL_KEY_PATH") @option( "--create-self-signed-cert", help="If certificate and key are not found at specified locations, create a self-signed certificate and a key", is_flag=True, + envvar="LITESTAR_CREATE_SELF_SIGNED_CERT", ) def run_command( reload: bool, @@ -207,20 +244,8 @@ def run_command( env: LitestarEnv = ctx.obj app = env.app - reload_dirs = env.reload_dirs or reload_dir - reload_include = env.reload_include or reload_include - reload_exclude = env.reload_exclude or reload_exclude - - host = env.host or host - port = env.port if env.port is not None else port - fd = env.fd if env.fd is not None else fd - uds = env.uds or uds - reload = env.reload or reload or bool(reload_dirs) or bool(reload_include) or bool(reload_exclude) - workers = env.web_concurrency or wc - - ssl_certfile = ssl_certfile or env.certfile_path - ssl_keyfile = ssl_keyfile or env.keyfile_path - create_self_signed_cert = create_self_signed_cert or env.create_self_signed_cert + reload = reload or bool(reload_dir) or bool(reload_include) or bool(reload_exclude) + workers = wc certfile_path, keyfile_path = ( create_ssl_files(ssl_certfile, ssl_keyfile, host) @@ -263,7 +288,7 @@ def run_command( port=port, workers=workers, reload=reload, - reload_dirs=reload_dirs, + reload_dirs=reload_dir, reload_include=reload_include, reload_exclude=reload_exclude, fd=fd, diff --git a/tests/unit/test_cli/test_core_commands.py b/tests/unit/test_cli/test_core_commands.py index 82c7727d93..e6c476bf9c 100644 --- a/tests/unit/test_cli/test_core_commands.py +++ b/tests/unit/test_cli/test_core_commands.py @@ -3,7 +3,7 @@ import re import sys from pathlib import Path -from typing import Callable, Generator, List, Optional, Tuple +from typing import Callable, Generator, List, Literal, Optional, Tuple, Union from unittest.mock import MagicMock import pytest @@ -117,7 +117,7 @@ def test_run_command( else: uds = None - if fd: + if fd is not None: if set_in_env: monkeypatch.setenv("LITESTAR_FILE_DESCRIPTOR", str(fd)) else: @@ -151,7 +151,6 @@ def test_run_command( args.extend([f"--reload-exclude={s}" for s in reload_exclude]) path = create_app_file(custom_app_file or "app.py", directory=app_dir) - result = runner.invoke(cli_command, args) assert result.exception is None @@ -257,6 +256,115 @@ def test_run_command_with_app_factory( ) +@pytest.mark.parametrize( + "cli, env, expected", + ( + ( + ("--reload", True), + ("LITESTAR_RELOAD", False), + "--reload", + ), + ( + ("--reload-dir", [".", "../somewhere_else"]), + ("LITESTAR_RELOAD_DIRS", ["../somewhere_else3", "../somewhere_else2"]), + ["--reload-dir=.", "--reload-dir=../somewhere_else"], + ), + ( + ("--reload-include", ["*.rst", "*.yml"]), + ("LITESTAR_RELOAD_INCLUDES", ["*.rst2", "*.yml2"]), + ["--reload-include=*.rst", "--reload-include=*.yml"], + ), + ( + ("--reload-exclude", ["*.rst", "*.yml"]), + ("LITESTAR_RELOAD_EXCLUDES", ["*.rst2", "*.yml2"]), + ["--reload-exclude=*.rst", "--reload-exclude=*.yml"], + ), + ( + ("--wc", 2), + ("WEB_CONCURRENCY", 4), + "--workers=2", + ), + ( + ("--fd", 0), + ("LITESTAR_FILE_DESCRIPTOR", 1), + "--fd=0", + ), + ( + ("--uds", "/run/uvicorn/litestar_test.sock"), + ("LITESTAR_UNIX_DOMAIN_SOCKET", "/run/uvicorn/litestar_test2.sock"), + "--uds=/run/uvicorn/litestar_test.sock", + ), + ( + ("-d", True), + ("LITESTAR_DEBUG", False), + ("LITESTAR_DEBUG", "1"), + ), + ( + ("--pdb", True), + ("LITESTAR_PDB", False), + ("LITESTAR_PDB", "1"), + ), + ), +) +def test_run_command_arguments_precedence( + cli: Tuple[str, Union[Literal[True], List[str], str]], + env: Tuple[str, Union[Literal[True], List[str], str]], + expected: str, + runner: CliRunner, + monkeypatch: MonkeyPatch, + mock_subprocess_run: MagicMock, + tmp_project_dir: Path, + create_app_file: CreateAppFileFixture, + mock_uvicorn_run: MagicMock, +) -> None: + args = [] + args.extend(["--app", f"{Path('my_app.py').stem}:app"]) + args.extend(["--app-dir", str(Path(tmp_project_dir / "custom_subfolder"))]) + args.extend(["run"]) + create_app_file("my_app.py", directory="custom_subfolder") + + env_name, env_value = env + cli_name, cli_value = cli + + if env_name: + if isinstance(env_value, list): + monkeypatch.setenv(env_name, "".join(env_value)) + else: + monkeypatch.setenv(env_name, env_value) # type: ignore[arg-type] # pyright: ignore (reportGeneralTypeIssues) + + if cli_name: + if cli_value is True: + args.append(cli_name) + elif isinstance(cli_value, list): + for value in cli_value: + args.extend([cli_name, value]) + else: + args.extend([cli_name, cli_value]) + + result = runner.invoke(cli_command, args) + + assert result.exception is None + assert result.exit_code == 0 + + if cli_name in ["--fd", "--uds"]: + mock_subprocess_run.assert_not_called() + if isinstance(expected, list): # type: ignore[unreachable] + assert all(_ in mock_uvicorn_run.call_args_list[0].args[0] for _ in expected) # type: ignore[unreachable] + else: + assert mock_uvicorn_run.call_args_list[0].kwargs.get(cli_name.strip("--")) == cli_value + + elif cli_name in ["-d", "--pdb"]: + assert os.environ.get(expected[0]) == expected[1] + + else: + mock_subprocess_run.assert_called_once() + + if isinstance(expected, list): # type: ignore[unreachable] + assert all(_ in mock_subprocess_run.call_args_list[0].args[0] for _ in expected) # type: ignore[unreachable] + else: + assert expected in mock_subprocess_run.call_args_list[0].args[0] + + @pytest.fixture() def unset_env() -> Generator[None, None, None]: initial_env = {**os.environ} diff --git a/tests/unit/test_cli/test_env_resolution.py b/tests/unit/test_cli/test_env_resolution.py index 3b89f6c288..35c701f249 100644 --- a/tests/unit/test_cli/test_env_resolution.py +++ b/tests/unit/test_cli/test_env_resolution.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Optional import pytest from _pytest.monkeypatch import MonkeyPatch @@ -13,29 +12,6 @@ pytestmark = pytest.mark.xdist_group("cli_autodiscovery") -@pytest.mark.parametrize("env_name,attr_name", [("LITESTAR_DEBUG", "debug"), ("LITESTAR_RELOAD", "reload")]) -@pytest.mark.parametrize( - "env_value,expected_value", - [("true", True), ("True", True), ("1", True), ("0", False), (None, False)], -) -def test_litestar_env_from_env_booleans( - monkeypatch: MonkeyPatch, - app_file: Path, - attr_name: str, - env_name: str, - env_value: Optional[str], - expected_value: bool, -) -> None: - monkeypatch.delenv(env_name, raising=False) - if env_value is not None: - monkeypatch.setenv(env_name, env_value) - - env = LitestarEnv.from_env(f"{app_file.stem}:app") - - assert getattr(env, attr_name) is expected_value - assert isinstance(env.app, Litestar) - - def test_litestar_env_from_env_port(monkeypatch: MonkeyPatch, app_file: Path) -> None: env = LitestarEnv.from_env(f"{app_file.stem}:app") assert env.port is None From fc80dfc193c917d6f1e3b1840e78201829678306 Mon Sep 17 00:00:00 2001 From: Cody Fincher <204685+cofin@users.noreply.github.com> Date: Mon, 18 Mar 2024 09:04:26 -0500 Subject: [PATCH 14/19] feat: only print when terminal is `TTY` enabled (#3219) --- litestar/cli/_utils.py | 25 +++++-- litestar/cli/commands/core.py | 3 +- tests/unit/test_cli/test_core_commands.py | 84 +++++++++++++++++++---- 3 files changed, 92 insertions(+), 20 deletions(-) diff --git a/litestar/cli/_utils.py b/litestar/cli/_utils.py index 29ed4826ea..ac082868eb 100644 --- a/litestar/cli/_utils.py +++ b/litestar/cli/_utils.py @@ -101,8 +101,8 @@ def from_env(cls, app_path: str | None, app_dir: Path | None = None) -> Litestar if app_path and getenv("LITESTAR_APP") is None: os.environ["LITESTAR_APP"] = app_path if app_path: - if not quiet_console: - console.print(f"Using {app_name} from env: [bright_blue]{app_path!r}") + if not quiet_console and isatty(): + console.print(f"Using {app_name} app from env: [bright_blue]{app_path!r}") loaded_app = _load_app_from_path(app_path) else: loaded_app = _autodiscover_app(cwd) @@ -303,6 +303,8 @@ def _autodiscovery_paths(base_dir: Path, arbitrary: bool = True) -> Generator[Pa def _autodiscover_app(cwd: Path) -> LoadedApp: + app_name = getenv("LITESTAR_APP_NAME") or "Litestar" + quiet_console = getenv("LITESTAR_QUIET_CONSOLE") or False for file_path in _autodiscovery_paths(cwd): import_path = _path_to_dotted_path(file_path.relative_to(cwd)) module = importlib.import_module(import_path) @@ -314,13 +316,15 @@ def _autodiscover_app(cwd: Path) -> LoadedApp: if isinstance(value, Litestar): app_string = f"{import_path}:{attr}" os.environ["LITESTAR_APP"] = app_string - console.print(f"Using Litestar app from [bright_blue]{app_string}") + if not quiet_console and isatty(): + console.print(f"Using {app_name} app from [bright_blue]{app_string}") return LoadedApp(app=value, app_path=app_string, is_factory=False) if hasattr(module, "create_app"): app_string = f"{import_path}:create_app" os.environ["LITESTAR_APP"] = app_string - console.print(f"Using Litestar factory [bright_blue]{app_string}") + if not quiet_console and isatty(): + console.print(f"Using {app_name} factory from [bright_blue]{app_string}") return LoadedApp(app=module.create_app(), app_path=app_string, is_factory=True) for attr, value in module.__dict__.items(): @@ -334,10 +338,11 @@ def _autodiscover_app(cwd: Path) -> LoadedApp: if return_annotation in ("Litestar", Litestar): app_string = f"{import_path}:{attr}" os.environ["LITESTAR_APP"] = app_string - console.print(f"Using Litestar factory [bright_blue]{app_string}") + if not quiet_console and sys.stdout.isatty(): + console.print(f"Using {app_name} factory from [bright_blue]{app_string}") return LoadedApp(app=value(), app_path=f"{app_string}", is_factory=True) - raise LitestarCLIException("Could not find a Litestar app or factory") + raise LitestarCLIException(f"Could not find {app_name} instance or factory") def _format_is_enabled(value: Any) -> str: @@ -544,3 +549,11 @@ def remove_default_schema_routes( else openapi_config.openapi_controller.path ) return remove_routes_with_patterns(routes, (schema_path,)) + + +def isatty() -> bool: + """Detect if a terminal is TTY enabled. + + This is a convenience wrapper around the built in system methods. This allows for easier testing of TTY/non-TTY modes. + """ + return sys.stdout.isatty() diff --git a/litestar/cli/commands/core.py b/litestar/cli/commands/core.py index e3273f1d82..803634bc1c 100644 --- a/litestar/cli/commands/core.py +++ b/litestar/cli/commands/core.py @@ -18,6 +18,7 @@ LitestarEnv, console, create_ssl_files, + isatty, remove_default_schema_routes, remove_routes_with_patterns, show_app_info, @@ -253,7 +254,7 @@ def run_command( else validate_ssl_file_paths(ssl_certfile, ssl_keyfile) ) - if not quiet_console: + if not quiet_console and isatty(): console.rule("[yellow]Starting server process", align="left") show_app_info(app) with _server_lifespan(app): diff --git a/tests/unit/test_cli/test_core_commands.py b/tests/unit/test_cli/test_core_commands.py index e6c476bf9c..3400dd328b 100644 --- a/tests/unit/test_cli/test_core_commands.py +++ b/tests/unit/test_cli/test_core_commands.py @@ -14,6 +14,7 @@ from litestar import __version__ as litestar_version from litestar.cli import _utils +from litestar.cli.commands import core from litestar.cli.main import litestar_group as cli_command from litestar.exceptions import LitestarWarning @@ -57,8 +58,11 @@ def mock_show_app_info(mocker: MockerFixture) -> MagicMock: (False, None, None, None, 2), ], ) +@pytest.mark.parametrize("tty_enabled", [True, False]) +@pytest.mark.parametrize("quiet_console", [True, False]) def test_run_command( mock_show_app_info: MagicMock, + mocker: MockerFixture, runner: CliRunner, monkeypatch: MonkeyPatch, reload: Optional[bool], @@ -74,10 +78,17 @@ def test_run_command( custom_app_file: Optional[Path], create_app_file: CreateAppFileFixture, set_in_env: bool, + tty_enabled: bool, + quiet_console: bool, mock_subprocess_run: MagicMock, mock_uvicorn_run: MagicMock, tmp_project_dir: Path, ) -> None: + monkeypatch.delenv("LITESTAR_QUIET_CONSOLE", raising=False) + if quiet_console: + monkeypatch.setenv("LITESTAR_QUIET_CONSOLE", "true") + mocker.patch.object(core, "isatty", return_value=tty_enabled) + mocker.patch.object(_utils, "isatty", return_value=tty_enabled) args = [] if custom_app_file: args.extend(["--app", f"{custom_app_file.stem}:app"]) @@ -194,9 +205,14 @@ def test_run_command( ssl_keyfile=None, ) - mock_show_app_info.assert_called_once() + if tty_enabled and not quiet_console: + mock_show_app_info.assert_called_once() + else: + mock_show_app_info.assert_not_called() +@pytest.mark.parametrize("quiet_console", [True, False]) +@pytest.mark.parametrize("tty_enabled", [True, False]) @pytest.mark.parametrize( "file_name,file_content,factory_name", [ @@ -213,12 +229,20 @@ def test_run_command_with_autodiscover_app_factory( file_content: str, factory_name: str, patch_autodiscovery_paths: Callable[[List[str]], None], + tty_enabled: bool, + quiet_console: bool, create_app_file: CreateAppFileFixture, + mocker: MockerFixture, + monkeypatch: MonkeyPatch, ) -> None: + monkeypatch.delenv("LITESTAR_QUIET_CONSOLE", raising=False) + if quiet_console: + monkeypatch.setenv("LITESTAR_QUIET_CONSOLE", "true") + mocker.patch.object(core, "isatty", return_value=tty_enabled) + mocker.patch.object(_utils, "isatty", return_value=tty_enabled) patch_autodiscovery_paths([file_name]) path = create_app_file(file_name, content=file_content) result = runner.invoke(cli_command, "run") - assert result.exception is None assert result.exit_code == 0 @@ -232,11 +256,28 @@ def test_run_command_with_autodiscover_app_factory( ssl_certfile=None, ssl_keyfile=None, ) + if tty_enabled and not quiet_console: + assert len(result.output) > 0 + else: + assert len(result.output) == 0 +@pytest.mark.parametrize("quiet_console", [True, False]) +@pytest.mark.parametrize("tty_enabled", [True, False]) def test_run_command_with_app_factory( - runner: CliRunner, mock_uvicorn_run: MagicMock, create_app_file: CreateAppFileFixture + runner: CliRunner, + mock_uvicorn_run: MagicMock, + create_app_file: CreateAppFileFixture, + tty_enabled: bool, + quiet_console: bool, + mocker: MockerFixture, + monkeypatch: MonkeyPatch, ) -> None: + monkeypatch.delenv("LITESTAR_QUIET_CONSOLE", raising=False) + if quiet_console: + monkeypatch.setenv("LITESTAR_QUIET_CONSOLE", "true") + mocker.patch.object(core, "isatty", return_value=tty_enabled) + mocker.patch.object(_utils, "isatty", return_value=tty_enabled) path = create_app_file("_create_app_with_path.py", content=CREATE_APP_FILE_CONTENT) app_path = f"{path.stem}:create_app" result = runner.invoke(cli_command, ["--app", app_path, "run"]) @@ -254,6 +295,10 @@ def test_run_command_with_app_factory( ssl_certfile=None, ssl_keyfile=None, ) + if tty_enabled and not quiet_console: + assert len(result.output) > 0 + else: + assert len(result.output) == 0 @pytest.mark.parametrize( @@ -390,9 +435,15 @@ def test_run_command_debug( @pytest.mark.usefixtures("mock_uvicorn_run", "unset_env") def test_run_command_quiet_console( - app_file: Path, runner: CliRunner, monkeypatch: MonkeyPatch, create_app_file: CreateAppFileFixture + app_file: Path, + mocker: MockerFixture, + runner: CliRunner, + monkeypatch: MonkeyPatch, + create_app_file: CreateAppFileFixture, ) -> None: - console = Console(file=io.StringIO()) + mocker.patch.object(core, "isatty", return_value=True) + mocker.patch.object(_utils, "isatty", return_value=True) + console = Console(file=io.StringIO(), force_interactive=True) monkeypatch.setattr(_utils, "console", console) path = create_app_file("_create_app_with_path.py", content=CREATE_APP_FILE_CONTENT) @@ -401,10 +452,10 @@ def test_run_command_quiet_console( result = runner.invoke(cli_command, ["--app", app_path, "run"]) assert result.exit_code == 0 normal_output = console.file.getvalue() # type: ignore[attr-defined] - assert "Using Litestar from env:" in normal_output + assert "Using Litestar app from env:" in normal_output assert "Starting server process" in result.stdout del result - console = Console(file=io.StringIO()) + console = Console(file=io.StringIO(), force_interactive=True) monkeypatch.setattr(_utils, "console", console) monkeypatch.setenv("LITESTAR_QUIET_CONSOLE", "1") assert os.getenv("LITESTAR_QUIET_CONSOLE") == "1" @@ -412,15 +463,22 @@ def test_run_command_quiet_console( assert result.exit_code == 0 quiet_output = console.file.getvalue() # type: ignore[attr-defined] assert "Starting server process" not in result.stdout - assert "Using Litestar from env:" not in quiet_output + assert "Using Litestar app from env:" not in quiet_output console.clear() @pytest.mark.usefixtures("mock_uvicorn_run", "unset_env") def test_run_command_custom_app_name( - app_file: Path, runner: CliRunner, monkeypatch: MonkeyPatch, create_app_file: CreateAppFileFixture + app_file: Path, + runner: CliRunner, + monkeypatch: MonkeyPatch, + create_app_file: CreateAppFileFixture, + mocker: MockerFixture, ) -> None: - console = Console(file=io.StringIO()) + mocker.patch.object(core, "isatty", return_value=True) + mocker.patch.object(_utils, "isatty", return_value=True) + + console = Console(file=io.StringIO(), force_interactive=True) monkeypatch.setattr(_utils, "console", console) path = create_app_file("_create_app_with_path.py", content=CREATE_APP_FILE_CONTENT) @@ -429,15 +487,15 @@ def test_run_command_custom_app_name( result = runner.invoke(cli_command, ["--app", app_path, "run"]) assert result.exit_code == 0 _output = console.file.getvalue() # type: ignore[attr-defined] - assert "Using Litestar from env:" in _output - console = Console(file=io.StringIO()) + assert "Using Litestar app from env:" in _output + console = Console(file=io.StringIO(), force_interactive=True) monkeypatch.setattr(_utils, "console", console) monkeypatch.setenv("LITESTAR_APP_NAME", "My Stuff") assert os.getenv("LITESTAR_APP_NAME") == "My Stuff" result = runner.invoke(cli_command, ["--app", app_path, "run"]) assert result.exit_code == 0 _output = console.file.getvalue() # type: ignore[attr-defined] - assert "Using My Stuff from env:" in _output + assert "Using My Stuff app from env:" in _output @pytest.mark.usefixtures("mock_uvicorn_run", "unset_env") From e64aee4c1d3e3fe202091954c047258844520d29 Mon Sep 17 00:00:00 2001 From: Tuukka Mustonen Date: Tue, 19 Mar 2024 17:34:33 +0200 Subject: [PATCH 15/19] feat: Support `schema_extra` in `Parameter` and `Body` (#3204) * feat: Support `schema_extra` in `Parameter` and `Body` (#3022) This adds sort of a backdoor for modifying the generated OpenAPI spec. The value is given as `dict[str, Any]` where the key must match with the keyword parameter name in `Schema`. The values are used to override items in the generated `Schema` object, so they must be in correct types (ie. not in dictionary/json format). The values are added at main level, without recursive merging (because we're adjusting `Schema` object and not a dictionary). Recursive merge would be much more work. Chose not to implement the same for `ResponseSpec` because response models are generated as schema components, while `ResponseSpec` can be locally different. Handling the logic of creating new components when `schema_extra` is passed in `ResponseSpec` would be extra effort, and isn't probably as important as being able to adjust the inbound parameters, which are actually validated (and for which the documentation is even more important, than for the response). * Update litestar/params.py Co-authored-by: Jacob Coffee * Update litestar/params.py Co-authored-by: Jacob Coffee * Update litestar/params.py Co-authored-by: Jacob Coffee --------- Co-authored-by: Jacob Coffee --- litestar/_openapi/schema_generation/schema.py | 10 ++++- litestar/params.py | 19 +++++++++ tests/unit/test_openapi/test_parameters.py | 42 ++++++++++++++++++- tests/unit/test_openapi/test_request_body.py | 28 ++++++++++++- 4 files changed, 96 insertions(+), 3 deletions(-) diff --git a/litestar/_openapi/schema_generation/schema.py b/litestar/_openapi/schema_generation/schema.py index 0b7d6c6fbf..43e5d328b4 100644 --- a/litestar/_openapi/schema_generation/schema.py +++ b/litestar/_openapi/schema_generation/schema.py @@ -50,7 +50,7 @@ from litestar.exceptions import ImproperlyConfiguredException from litestar.openapi.spec.enums import OpenAPIFormat, OpenAPIType from litestar.openapi.spec.schema import Schema, SchemaDataContainer -from litestar.params import BodyKwarg, ParameterKwarg +from litestar.params import BodyKwarg, KwargDefinition, ParameterKwarg from litestar.plugins import OpenAPISchemaPlugin from litestar.types import Empty from litestar.types.builtin_types import NoneType @@ -569,6 +569,14 @@ def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schem if getattr(schema, schema_key, None) is None: setattr(schema, schema_key, value) + if isinstance(field.kwarg_definition, KwargDefinition) and (extra := field.kwarg_definition.schema_extra): + for schema_key, value in extra.items(): + if not hasattr(schema, schema_key): + raise ValueError( + f"`schema_extra` declares key `{schema_key}` which does not exist in `Schema` object" + ) + setattr(schema, schema_key, value) + if not schema.examples and self.generate_examples: from litestar._openapi.schema_generation.examples import create_examples_for_field diff --git a/litestar/params.py b/litestar/params.py index bff010bb7e..c52389e0f4 100644 --- a/litestar/params.py +++ b/litestar/params.py @@ -112,6 +112,13 @@ class KwargDefinition: """A sequence of valid values.""" read_only: bool | None = field(default=None) """A boolean flag dictating whether this parameter is read only.""" + schema_extra: dict[str, Any] | None = field(default=None) + """Extensions to the generated schema. + + If set, will overwrite the matching fields in the generated schema. + + .. versionadded:: 2.8.0 + """ @property def is_constrained(self) -> bool: @@ -187,6 +194,7 @@ def Parameter( query: str | None = None, required: bool | None = None, title: str | None = None, + schema_extra: dict[str, Any] | None = None, ) -> Any: """Create an extended parameter kwarg definition. @@ -227,6 +235,10 @@ def Parameter( required: A boolean flag dictating whether this parameter is required. If set to False, None values will be allowed. Defaults to True. title: String value used in the title section of the OpenAPI schema for the given parameter. + schema_extra: Extensions to the generated schema. If set, will overwrite the matching fields in the generated + schema. + + .. versionadded:: 2.8.0 """ return ParameterKwarg( annotation=annotation, @@ -251,6 +263,7 @@ def Parameter( min_length=min_length, max_length=max_length, pattern=pattern, + schema_extra=schema_extra, ) @@ -294,6 +307,7 @@ def Body( multiple_of: float | None = None, pattern: str | None = None, title: str | None = None, + schema_extra: dict[str, Any] | None = None, ) -> Any: """Create an extended request body kwarg definition. @@ -331,6 +345,10 @@ def Body( pattern: A string representing a regex against which the given string will be matched. Equivalent to pattern in the OpenAPI specification. title: String value used in the title section of the OpenAPI schema for the given parameter. + schema_extra: Extensions to the generated schema. If set, will overwrite the matching fields in the generated + schema. + + .. versionadded:: 2.8.0 """ return BodyKwarg( media_type=media_type, @@ -352,6 +370,7 @@ def Body( max_length=max_length, pattern=pattern, multipart_form_part_limit=multipart_form_part_limit, + schema_extra=schema_extra, ) diff --git a/tests/unit/test_openapi/test_parameters.py b/tests/unit/test_openapi/test_parameters.py index 1bec6991bd..48cea53df0 100644 --- a/tests/unit/test_openapi/test_parameters.py +++ b/tests/unit/test_openapi/test_parameters.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional, Type, cast +from typing import TYPE_CHECKING, Any, List, Optional, Type, cast from uuid import UUID import pytest @@ -324,6 +324,46 @@ async def index( } +def test_parameter_schema_extra() -> None: + @get() + async def handler( + query1: Annotated[ + str, + Parameter( + schema_extra={ + "schema_not": Schema( + any_of=[ + Schema(type=OpenAPIType.STRING, pattern=r"^somePrefix:.*$"), + Schema(type=OpenAPIType.STRING, enum=["denied", "values"]), + ] + ), + } + ), + ], + ) -> Any: + return query1 + + @get() + async def error_handler(query1: Annotated[str, Parameter(schema_extra={"invalid": "dummy"})]) -> Any: + return query1 + + # Success + app = Litestar([handler]) + schema = app.openapi_schema.to_schema() + assert schema["paths"]["/"]["get"]["parameters"][0]["schema"]["not"] == { + "anyOf": [ + {"type": "string", "pattern": r"^somePrefix:.*$"}, + {"type": "string", "enum": ["denied", "values"]}, + ] + } + + # Attempt to pass invalid key + app = Litestar([error_handler]) + with pytest.raises(ValueError) as e: + app.openapi_schema + assert str(e.value).startswith("`schema_extra` declares key") + + def test_uuid_path_description_generation() -> None: # https://github.com/litestar-org/litestar/issues/2967 @get("str/{id:str}") diff --git a/tests/unit/test_openapi/test_request_body.py b/tests/unit/test_openapi/test_request_body.py index 1a9b072c96..05b196ebfc 100644 --- a/tests/unit/test_openapi/test_request_body.py +++ b/tests/unit/test_openapi/test_request_body.py @@ -3,8 +3,9 @@ from unittest.mock import ANY, MagicMock import pytest +from typing_extensions import Annotated -from litestar import Controller, Litestar, post +from litestar import Controller, Litestar, get, post from litestar._openapi.datastructures import OpenAPIContext from litestar._openapi.request_body import create_request_body from litestar.datastructures.upload_file import UploadFile @@ -56,6 +57,31 @@ def test_create_request_body(person_controller: Type[Controller], create_request assert request_body +def test_request_body_schema_extra() -> None: + @dataclass + class RequestBody: + foo: str + + @get() + async def handler( + body1: Annotated[ + RequestBody, + Body( + title="Default title", + schema_extra={ + "title": "Overridden title", + }, + ), + ], + ) -> Any: + return body1 + + app = Litestar([handler]) + schema = app.openapi_schema.to_schema() + resp = next(iter(schema["components"]["schemas"].values())) + assert resp["title"] == "Overridden title" + + def test_upload_single_file_schema_generation() -> None: @post(path="/file-upload") async def handle_file_upload( From 91008ea1449f2e70fe0f2c0075ffd1a42ec6dca3 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Fri, 22 Mar 2024 15:05:03 +1000 Subject: [PATCH 16/19] docs: improvement for openapi router config. (#3235) Adds deprecated directives for the deprecated parameters of the config. Adds some cross-references. --- litestar/openapi/config.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/litestar/openapi/config.py b/litestar/openapi/config.py index f1fd0acf12..ebb763eb47 100644 --- a/litestar/openapi/config.py +++ b/litestar/openapi/config.py @@ -115,7 +115,7 @@ class OpenAPIConfig: If no path is provided the default is ``/schema``. - Ignored if ``openapi_router`` is provided. + Ignored if :attr:`openapi_router` is provided. """ render_plugins: Sequence[OpenAPIRenderPlugin] = field(default=()) """Plugins for rendering OpenAPI documentation UIs.""" @@ -124,22 +124,32 @@ class OpenAPIConfig: If provided, ``path`` is ignored. - This parameter is also ignored if the deprecated :class:`OpenAPIConfig <.openapi.OpenAPIConfig>` ``openapi_controller`` kwarg is provided. + This parameter is also ignored if the deprecated :attr:`openapi_router <.openapi.OpenAPIConfig.openapi_controller>` + kwarg is provided. - The ``openapi_router`` is not required, but it can be passed to customize the configuration of the router used to serve the documentation endpoints. For example, you can add middleware or guards to the router. + :attr:`openapi_router` is not required, but it can be passed to customize the configuration of the router used to + serve the documentation endpoints. For example, you can add middleware or guards to the router. - Handlers to serve the OpenAPI schema and documentation sites are added to this router according - to the ``render_plugins`` attribute, so routes shouldn't be added that conflict with these. + Handlers to serve the OpenAPI schema and documentation sites are added to this router according to + :attr:`render_plugins`, so routes shouldn't be added that conflict with these. """ openapi_controller: type[OpenAPIController] | None = None """Controller for generating OpenAPI routes. Must be subclass of :class:`OpenAPIController `. + + .. deprecated:: v2.8.0 """ root_schema_site: Literal["redoc", "swagger", "elements", "rapidoc"] | None = None - """The static schema generator to use for the "root" path of ``/schema/``.""" + """The static schema generator to use for the "root" path of ``/schema/``. + + .. deprecated:: v2.8.0 + """ enabled_endpoints: set[str] | None = None - """A set of the enabled documentation sites and schema download endpoints.""" + """A set of the enabled documentation sites and schema download endpoints. + + .. deprecated:: v2.8.0 + """ def __post_init__(self) -> None: self._issue_deprecations() From af082e880a847648d2040e99a53e7f3019361622 Mon Sep 17 00:00:00 2001 From: harryle <64817481+haryle@users.noreply.github.com> Date: Sat, 23 Mar 2024 18:35:16 +1030 Subject: [PATCH 17/19] feat: add typevar expansion (#3242) * feat: add typevar expansion #3240 * chore: resolve all PR suggestion #3242 * chore: resolve import formatting * chore: resolve import formatting --- litestar/utils/signature.py | 5 +- litestar/utils/typing.py | 15 ++++++ tests/unit/test_utils/test_signature.py | 62 ++++++++++++++++++++++++- tests/unit/test_utils/test_typing.py | 25 ++++++++++ 4 files changed, 104 insertions(+), 3 deletions(-) diff --git a/litestar/utils/signature.py b/litestar/utils/signature.py index eb585990e0..c387b7f93b 100644 --- a/litestar/utils/signature.py +++ b/litestar/utils/signature.py @@ -14,7 +14,7 @@ from litestar.exceptions import ImproperlyConfiguredException from litestar.types import Empty from litestar.typing import FieldDefinition -from litestar.utils.typing import unwrap_annotation +from litestar.utils.typing import expand_type_var_in_type_hint, unwrap_annotation if TYPE_CHECKING: from typing import Sequence @@ -212,8 +212,9 @@ def from_fn(cls, fn: AnyCallable, signature_namespace: dict[str, Any]) -> Self: """ signature = Signature.from_callable(fn) fn_type_hints = get_fn_type_hints(fn, namespace=signature_namespace) + expanded_type_hints = expand_type_var_in_type_hint(fn_type_hints, signature_namespace) - return cls.from_signature(signature, fn_type_hints) + return cls.from_signature(signature, expanded_type_hints) @classmethod def from_signature(cls, signature: Signature, fn_type_hints: dict[str, type]) -> Self: diff --git a/litestar/utils/typing.py b/litestar/utils/typing.py index 9da6c2a6f6..cae445a90e 100644 --- a/litestar/utils/typing.py +++ b/litestar/utils/typing.py @@ -262,6 +262,21 @@ def get_type_hints_with_generics_resolved( return {n: _substitute_typevars(type_, typevar_map) for n, type_ in type_hints.items()} +def expand_type_var_in_type_hint(type_hint: dict[str, Any], namespace: dict[str, Any] | None) -> dict[str, Any]: + """Expand TypeVar for any parameters in type_hint + + Args: + type_hint: mapping of parameter to type obtained from calling `get_type_hints` or `get_fn_type_hints` + namespace: mapping of TypeVar to concrete type + + Returns: + type_hint with any TypeVar parameter expanded + """ + if namespace: + return {name: _substitute_typevars(hint, namespace) for name, hint in type_hint.items()} + return type_hint + + def _substitute_typevars(obj: Any, typevar_map: Mapping[Any, Any]) -> Any: if params := getattr(obj, "__parameters__", None): args = tuple(_substitute_typevars(typevar_map.get(p, p), typevar_map) for p in params) diff --git a/tests/unit/test_utils/test_signature.py b/tests/unit/test_utils/test_signature.py index 1f9767b8c2..8f0de7c38d 100644 --- a/tests/unit/test_utils/test_signature.py +++ b/tests/unit/test_utils/test_signature.py @@ -5,11 +5,12 @@ import inspect from inspect import Parameter from types import ModuleType -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, Generic, List, Optional, TypeVar, Union import pytest from typing_extensions import Annotated, NotRequired, Required, TypedDict, get_args, get_type_hints +from litestar import Controller, Router, post from litestar.exceptions import ImproperlyConfiguredException from litestar.file_system import BaseLocalFileSystem from litestar.static_files import StaticFiles @@ -20,6 +21,10 @@ from litestar.utils.signature import ParsedSignature, add_types_to_signature_namespace, get_fn_type_hints T = TypeVar("T") +U = TypeVar("U") + + +class ConcreteT: ... def test_get_fn_type_hints_asgi_app() -> None: @@ -161,3 +166,58 @@ def test_add_types_to_signature_namespace_with_existing_types_raises() -> None: """Test add_types_to_signature_namespace with existing types raises.""" with pytest.raises(ImproperlyConfiguredException): add_types_to_signature_namespace([int], {"int": int}) + + +@pytest.mark.parametrize( + ("namespace", "expected"), + ( + ({T: int}, {"data": int, "return": int}), + ({}, {"data": T, "return": T}), + ({T: ConcreteT}, {"data": ConcreteT, "return": ConcreteT}), + ), +) +def test_using_generics_in_fn_annotations(namespace: dict[str, Any], expected: dict[str, Any]) -> None: + @post(signature_namespace=namespace) + def create_item(data: T) -> T: + return data + + signature = create_item.parsed_fn_signature + actual = {"data": signature.parameters["data"].annotation, "return": signature.return_type.annotation} + assert actual == expected + + +class GenericController(Controller, Generic[T]): + model_class: T + + def __class_getitem__(cls, model_class: type) -> type: + cls_dict = {"model_class": model_class} + return type(f"GenericController[{model_class.__name__}", (cls,), cls_dict) + + def __init__(self, owner: Router) -> None: + super().__init__(owner) + self.signature_namespace[T] = self.model_class # type: ignore[misc] + + +class BaseController(GenericController[T]): + @post() + async def create(self, data: T) -> T: + return data + + +@pytest.mark.parametrize( + ("annotation_type", "expected"), + ( + (int, {"data": int, "return": int}), + (float, {"data": float, "return": float}), + (ConcreteT, {"data": ConcreteT, "return": ConcreteT}), + ), +) +def test_using_generics_in_controller_annotations(annotation_type: type, expected: dict[str, Any]) -> None: + class ConcreteController(BaseController[annotation_type]): # type: ignore[valid-type] + path = "/" + + controller_object = ConcreteController(owner=None) # type: ignore[arg-type] + + signature = controller_object.get_route_handlers()[0].parsed_fn_signature + actual = {"data": signature.parameters["data"].annotation, "return": signature.return_type.annotation} + assert actual == expected diff --git a/tests/unit/test_utils/test_typing.py b/tests/unit/test_utils/test_typing.py index 38f4a44174..5b9fd95b3f 100644 --- a/tests/unit/test_utils/test_typing.py +++ b/tests/unit/test_utils/test_typing.py @@ -9,6 +9,7 @@ from typing_extensions import Annotated from litestar.utils.typing import ( + expand_type_var_in_type_hint, get_origin_or_inner_type, get_type_hints_with_generics_resolved, make_non_optional_union, @@ -134,3 +135,27 @@ class NestedFoo(Generic[T]): ) def test_get_type_hints_with_generics(annotation: Any, expected_type_hints: dict[str, Any]) -> None: assert get_type_hints_with_generics_resolved(annotation, include_extras=True) == expected_type_hints + + +class ConcreteT: ... + + +@pytest.mark.parametrize( + ("type_hint", "namespace", "expected"), + ( + ({"arg1": T, "return": int}, {}, {"arg1": T, "return": int}), + ({"arg1": T, "return": int}, None, {"arg1": T, "return": int}), + ({"arg1": T, "return": int}, {U: ConcreteT}, {"arg1": T, "return": int}), + ({"arg1": T, "return": int}, {T: ConcreteT}, {"arg1": ConcreteT, "return": int}), + ({"arg1": T, "return": int}, {T: int}, {"arg1": int, "return": int}), + ({"arg1": int, "return": int}, {}, {"arg1": int, "return": int}), + ({"arg1": int, "return": int}, None, {"arg1": int, "return": int}), + ({"arg1": int, "return": int}, {T: int}, {"arg1": int, "return": int}), + ({"arg1": T, "return": T}, {T: ConcreteT}, {"arg1": ConcreteT, "return": ConcreteT}), + ({"arg1": T, "return": T}, {T: int}, {"arg1": int, "return": int}), + ), +) +def test_expand_type_var_in_type_hints( + type_hint: dict[str, Any], namespace: dict[str, Any] | None, expected: dict[str, Any] +) -> None: + assert expand_type_var_in_type_hint(type_hint, namespace) == expected From 3c2c5994b92a9cb517c3670e72f1a90698dca70a Mon Sep 17 00:00:00 2001 From: aranvir <75439739+aranvir@users.noreply.github.com> Date: Sat, 23 Mar 2024 21:50:14 +0100 Subject: [PATCH 18/19] docs: Add examples for auth `exclude` configuration (#3246) * adding draft for security exclusion docs * adding section to security toctree * Update docs/usage/security/excluding-and-including-endpoints.rst * Update docs/usage/security/excluding-and-including-endpoints.rst * Update docs/usage/security/excluding-and-including-endpoints.rst --------- Co-authored-by: Jacob Coffee --- .../excluding-and-including-endpoints.rst | 86 +++++++++++++++++++ docs/usage/security/index.rst | 1 + 2 files changed, 87 insertions(+) create mode 100644 docs/usage/security/excluding-and-including-endpoints.rst diff --git a/docs/usage/security/excluding-and-including-endpoints.rst b/docs/usage/security/excluding-and-including-endpoints.rst new file mode 100644 index 0000000000..c63142bbd1 --- /dev/null +++ b/docs/usage/security/excluding-and-including-endpoints.rst @@ -0,0 +1,86 @@ +Excluding and including endpoints +================================= + +Please make sure you read the :doc:`security backends documentation ` first for learning how to set up a security backend. This section focuses on configuring the ``exclude`` rule for those backends. + +There are multiple ways for including or excluding endpoints in the authentication flow. The default rules are configured in the ``Auth`` object used (subclass of :class:`~.security.base.AbstractSecurityConfig`). The examples below use :class:`~.security.session_auth.auth.SessionAuth` but it is the same for :class:`~.security.jwt.auth.JWTAuth` and :class:`~.security.jwt.auth.JWTCookieAuth`. + +Excluding routes +-------------------- + +The ``exclude`` argument takes a :class:`string ` or :class:`list` of :class:`strings ` that are interpreted as regex patterns. For example, the configuration below would apply authentication to all endpoints except those where the route starts with ``/login``, ``/signup``, or ``/schema``. Thus, one does not have to exclude ``/schema/swagger`` as well - it is included in the ``/schema`` pattern. + +This also means that passing ``/`` will disable authentication for all routes. + +.. code-block:: python + + session_auth = SessionAuth[User, ServerSideSessionBackend]( + retrieve_user_handler=retrieve_user_handler, + # we must pass a config for a session backend. + # all session backends are supported + session_backend_config=ServerSideSessionConfig(), + # exclude any URLs that should not have authentication. + # We exclude the documentation URLs, signup and login. + exclude=["/login", "/signup", "/schema"], + ) + ... + +Including routes +---------------- + +Since the exclusion rules are evaluated as regex, it is possible to pass a rule that inverts exclusion - meaning, no path but the one specified in the pattern will be protected by authentication. In the example below, only endpoints under the ``/secured`` route will require authentication - all other routes do not. + +.. code-block:: python + + ... + session_auth = SessionAuth[User, ServerSideSessionBackend]( + retrieve_user_handler=retrieve_user_handler, + # we must pass a config for a session backend. + # all session backends are supported + session_backend_config=ServerSideSessionConfig(), + # exclude any URLs that should not have authentication. + # We exclude the documentation URLs, signup and login. + exclude=[r"^(?!.*\/secured$).*$"], + ) + ... + +Exclude from auth +-------------------- +Sometimes, you might want to apply authentication to all endpoints under a route but a few selected. In this case, you can pass ``exclude_from_auth=True`` to the route handler as shown below. + +.. code-block:: python + + ... + @get("/secured") + def secured_route() -> Any: + ... + + @get("/unsecured", exclude_from_auth=True) + def unsecured_route() -> Any: + ... + ... + +You can set an alternative option key in the security configuration, e.g., you can use ``no_auth`` instead of ``exclude_from_auth``. + +.. code-block:: python + + ... + @get("/secured") + def secured_route() -> Any: + ... + + @get("/unsecured", no_auth=True) + def unsecured_route() -> Any: + ... + + session_auth = SessionAuth[User, ServerSideSessionBackend]( + retrieve_user_handler=retrieve_user_handler, + # we must pass a config for a session backend. + # all session backends are supported + session_backend_config=ServerSideSessionConfig(), + # exclude any URLs that should not have authentication. + # We exclude the documentation URLs, signup and login. + exclude=["/login", "/signup", "/schema"], + exclude_opt_key="no_auth" # default value is `exclude_from_auth` + ) + ... diff --git a/docs/usage/security/index.rst b/docs/usage/security/index.rst index 68e64c0ebb..e2daa1aa08 100644 --- a/docs/usage/security/index.rst +++ b/docs/usage/security/index.rst @@ -12,4 +12,5 @@ authentication and authorization. abstract-authentication-middleware security-backends guards + excluding-and-including-endpoints jwt From 150b642eb104e2c810c90941430748c64334123e Mon Sep 17 00:00:00 2001 From: kedod <35638715+kedod@users.noreply.github.com> Date: Wed, 27 Mar 2024 16:49:26 +0100 Subject: [PATCH 19/19] feat: Add LITESTAR_ prefix before WEB_CONCURRENCY env option (#3227) * feat: Add LITESTAR_ prefix for web concurrency env option * Replace depacrated with versionchanged directive * Change wc option description * Remove depracation warning --------- Co-authored-by: kedod --- docs/usage/cli.rst | 65 ++++++++++++----------- litestar/cli/commands/core.py | 2 +- tests/unit/test_cli/test_core_commands.py | 4 +- 3 files changed, 37 insertions(+), 34 deletions(-) diff --git a/docs/usage/cli.rst b/docs/usage/cli.rst index 8333d1694f..c1de49b3fe 100644 --- a/docs/usage/cli.rst +++ b/docs/usage/cli.rst @@ -103,37 +103,40 @@ The ``run`` command executes a Litestar application using `uvicorn None: type=click.IntRange(min=1, max=multiprocessing.cpu_count() + 1), show_default=True, default=1, - envvar="WEB_CONCURRENCY", + envvar=["LITESTAR_WEB_CONCURRENCY", "WEB_CONCURRENCY"], ) @option("-H", "--host", help="Server under this host", default="127.0.0.1", show_default=True, envvar="LITESTAR_HOST") @option( diff --git a/tests/unit/test_cli/test_core_commands.py b/tests/unit/test_cli/test_core_commands.py index 3400dd328b..bc709a030a 100644 --- a/tests/unit/test_cli/test_core_commands.py +++ b/tests/unit/test_cli/test_core_commands.py @@ -139,7 +139,7 @@ def test_run_command( if web_concurrency is None: web_concurrency = 1 elif set_in_env: - monkeypatch.setenv("WEB_CONCURRENCY", str(web_concurrency)) + monkeypatch.setenv("LITESTAR_WEB_CONCURRENCY", str(web_concurrency)) else: args.extend(["--web-concurrency", str(web_concurrency)]) @@ -326,7 +326,7 @@ def test_run_command_with_app_factory( ), ( ("--wc", 2), - ("WEB_CONCURRENCY", 4), + ("LITESTAR_WEB_CONCURRENCY", 4), "--workers=2", ), (