diff --git a/docs/examples/contrib/sqlalchemy/sqlalchemy_declarative_models.py b/docs/examples/contrib/sqlalchemy/sqlalchemy_declarative_models.py index 984181dc3a..cae8ff5ab1 100644 --- a/docs/examples/contrib/sqlalchemy/sqlalchemy_declarative_models.py +++ b/docs/examples/contrib/sqlalchemy/sqlalchemy_declarative_models.py @@ -1,16 +1,17 @@ +from __future__ import annotations + +import uuid from datetime import date -from typing import TYPE_CHECKING +from typing import List from uuid import UUID -from sqlalchemy import ForeignKey, select +from sqlalchemy import ForeignKey, func, select +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship from litestar import Litestar, get from litestar.contrib.sqlalchemy.base import UUIDAuditBase, UUIDBase -from litestar.contrib.sqlalchemy.plugins import AsyncSessionConfig, SQLAlchemyAsyncConfig, SQLAlchemyInitPlugin - -if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from litestar.contrib.sqlalchemy.plugins import AsyncSessionConfig, SQLAlchemyAsyncConfig, SQLAlchemyPlugin # the SQLAlchemy base includes a declarative model for you to use in your models. @@ -18,7 +19,7 @@ class Author(UUIDBase): name: Mapped[str] dob: Mapped[date] - books: Mapped[list["Book"]] = relationship(back_populates="author", lazy="selectin") + books: Mapped[List[Book]] = relationship(back_populates="author", lazy="selectin") # The `AuditBase` class includes the same UUID` based primary key (`id`) and 2 @@ -32,19 +33,24 @@ class Book(UUIDAuditBase): session_config = AsyncSessionConfig(expire_on_commit=False) sqlalchemy_config = SQLAlchemyAsyncConfig( - connection_string="sqlite+aiosqlite:///test.sqlite", session_config=session_config + connection_string="sqlite+aiosqlite:///test.sqlite", session_config=session_config, create_all=True ) # Create 'async_session' dependency. -sqlalchemy_plugin = SQLAlchemyInitPlugin(config=sqlalchemy_config) async def on_startup() -> None: - """Initializes the database.""" - async with sqlalchemy_config.get_engine().begin() as conn: - await conn.run_sync(UUIDBase.metadata.create_all) + """Adds some dummy data if no data is present.""" + async with sqlalchemy_config.get_session() as session: + statement = select(func.count()).select_from(Author) + count = await session.execute(statement) + if not count.scalar(): + author_id = uuid.uuid4() + session.add(Author(name="Stephen King", dob=date(1954, 9, 21), id=author_id)) + session.add(Book(title="It", author_id=author_id)) + await session.commit() @get(path="/authors") -async def get_authors(db_session: "AsyncSession", db_engine: "AsyncEngine") -> list[Author]: +async def get_authors(db_session: AsyncSession, db_engine: AsyncEngine) -> List[Author]: """Interact with SQLAlchemy engine and session.""" return list(await db_session.scalars(select(Author))) @@ -52,5 +58,6 @@ async def get_authors(db_session: "AsyncSession", db_engine: "AsyncEngine") -> l app = Litestar( route_handlers=[get_authors], on_startup=[on_startup], - plugins=[SQLAlchemyInitPlugin(config=sqlalchemy_config)], + debug=True, + plugins=[SQLAlchemyPlugin(config=sqlalchemy_config)], ) diff --git a/pyproject.toml b/pyproject.toml index 5e113632fe..881b4d291a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -395,6 +395,7 @@ known-first-party = ["litestar", "tests", "examples"] "docs/examples/application_hooks/before_send_hook.py" = ["UP006"] "docs/examples/contrib/sqlalchemy/plugins/**/*.*" = ["UP006"] "docs/examples/data_transfer_objects**/*.*" = ["UP006"] +"docs/examples/contrib/sqlalchemy/sqlalchemy_declarative_models.py" = ["UP006"] "litestar/_openapi/schema_generation/schema.py" = ["C901"] "litestar/exceptions/*.*" = ["N818"] "litestar/handlers/**/*.*" = ["N801"] diff --git a/tests/examples/test_contrib/test_sqlalchemy/test_sqlalchemy_examples.py b/tests/examples/test_contrib/test_sqlalchemy/test_sqlalchemy_examples.py new file mode 100644 index 0000000000..0aa531bd33 --- /dev/null +++ b/tests/examples/test_contrib/test_sqlalchemy/test_sqlalchemy_examples.py @@ -0,0 +1,14 @@ +import pytest + +from litestar.testing import TestClient + +pytestmark = pytest.mark.xdist_group("sqlalchemy_examples") + + +def test_sqlalchemy_declarative_models() -> None: + from docs.examples.contrib.sqlalchemy.sqlalchemy_declarative_models import app + + with TestClient(app) as client: + response = client.get("/authors") + assert response.status_code == 200 + assert len(response.json()) > 0