From 9537281e471243c03a06cae41af9ae3c6eb8b21d Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Thu, 14 Dec 2023 10:29:55 -0500 Subject: [PATCH 1/4] Switch to async access to the database --- src/nextline_rdb/pagination.py | 14 ++--- src/nextline_rdb/plugin.py | 24 ++++---- .../schema/pagination/connection.py | 13 +++-- src/nextline_rdb/schema/pagination/db.py | 8 +-- src/nextline_rdb/schema/query.py | 4 +- src/nextline_rdb/schema/types.py | 20 +++---- tests/pagination/conftest.py | 45 +++++++-------- tests/pagination/test_falsy_id.py | 13 +++-- tests/pagination/test_fixture.py | 5 +- tests/pagination/test_pagination.py | 21 +++---- tests/pagination/test_sort.py | 5 +- tests/schema/queries/test_history.py | 56 ------------------- tests/schema/queries/test_pagenation.py | 48 ++++++++-------- 13 files changed, 112 insertions(+), 164 deletions(-) diff --git a/src/nextline_rdb/pagination.py b/src/nextline_rdb/pagination.py index be8e1ad..745b887 100644 --- a/src/nextline_rdb/pagination.py +++ b/src/nextline_rdb/pagination.py @@ -1,11 +1,11 @@ from typing import NamedTuple, Optional, Type, TypeVar, cast from sqlalchemy import func, select -from sqlalchemy.orm import aliased +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import aliased, DeclarativeBase from sqlalchemy.sql.expression import literal from sqlalchemy.sql.selectable import Select -from . import models as db_models # import sqlparse @@ -24,9 +24,9 @@ class SortField(NamedTuple): _Id = TypeVar("_Id") -def load_models( - session, - Model: Type[db_models.Model], +async def load_models( + session: AsyncSession, + Model: Type[DeclarativeBase], id_field: str, *, sort: Optional[Sort] = None, @@ -45,12 +45,12 @@ def load_models( last=last, ) - models = session.scalars(stmt) + models = await session.scalars(stmt) return models def compose_statement( - Model: Type[db_models.Model], + Model: Type[DeclarativeBase], id_field: str, *, sort: Optional[Sort] = None, diff --git a/src/nextline_rdb/plugin.py b/src/nextline_rdb/plugin.py index 118f0a4..5933770 100644 --- a/src/nextline_rdb/plugin.py +++ b/src/nextline_rdb/plugin.py @@ -8,12 +8,12 @@ from nextline import Nextline from nextlinegraphql.hook import spec from sqlalchemy import func, select -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from . import models -from .db import DB +from .db import AsyncDB from .schema import Mutation, Query, Subscription -from .write import write_db +from .write import async_write_db HERE = Path(__file__).resolve().parent DEFAULT_CONFIG_PATH = HERE / 'default.toml' @@ -50,14 +50,14 @@ def schema(self) -> tuple[type, type | None, type | None]: @asynccontextmanager async def lifespan(self, context: Mapping) -> AsyncIterator[None]: nextline = context['nextline'] - self._db = DB(self._url) - with self._db: + self._db = AsyncDB(self._url) + async with self._db: await self._initialize_nextline(nextline) - async with write_db(nextline, self._db): + async with async_write_db(nextline, self._db): yield async def _initialize_nextline(self, nextline: Nextline) -> None: - run_no, script = self._last_run_no_and_script() + run_no, script = await self._last_run_no_and_script() if run_no is not None: run_no += 1 if run_no >= nextline._init_options.run_no_start_from: @@ -65,17 +65,17 @@ async def _initialize_nextline(self, nextline: Nextline) -> None: if script is not None: nextline._init_options.statement = script - def _last_run_no_and_script(self) -> tuple[Optional[int], Optional[str]]: - with self._db.session() as session: - last_run = self._last_run(session) + async def _last_run_no_and_script(self) -> tuple[Optional[int], Optional[str]]: + async with self._db.session() as session: + last_run = await self._last_run(session) if last_run is None: return None, None else: return last_run.run_no, last_run.script - def _last_run(self, session: Session) -> Optional[models.Run]: + async def _last_run(self, session: AsyncSession) -> Optional[models.Run]: stmt = select(models.Run, func.max(models.Run.run_no)) - if model := session.execute(stmt).scalar_one_or_none(): + if model := (await session.execute(stmt)).scalar_one_or_none(): return model else: logger = getLogger(__name__) diff --git a/src/nextline_rdb/schema/pagination/connection.py b/src/nextline_rdb/schema/pagination/connection.py index 522be8f..7ee85d0 100644 --- a/src/nextline_rdb/schema/pagination/connection.py +++ b/src/nextline_rdb/schema/pagination/connection.py @@ -4,7 +4,8 @@ Relay doc: https://relay.dev/graphql/connections.htm """ -from typing import Callable, Generic, Optional, TypeVar +from collections.abc import Callable, Coroutine +from typing import Any, Generic, Optional, TypeVar import strawberry from strawberry.types import Info @@ -32,9 +33,9 @@ class Connection(Generic[_T]): edges: list[Edge[_T]] -def query_connection( +async def query_connection( info: Info, - query_edges: Callable[..., list[Edge[_T]]], + query_edges: Callable[..., Coroutine[Any, Any, list[Edge[_T]]]], before: Optional[str] = None, after: Optional[str] = None, first: Optional[int] = None, @@ -49,19 +50,19 @@ def query_connection( if forward: if first is not None: first += 1 # add one for has_next_page - edges = query_edges(info=info, after=after, first=first) + edges = await query_edges(info=info, after=after, first=first) has_previous_page = not not after if has_next_page := len(edges) == first: edges = edges[:-1] elif backward: if last is not None: last += 1 # add one for has_previous_page - edges = query_edges(info=info, before=before, last=last) + edges = await query_edges(info=info, before=before, last=last) if has_previous_page := len(edges) == last: edges = edges[1:] has_next_page = not not before else: - edges = query_edges(info) + edges = await query_edges(info) has_previous_page = False has_next_page = False diff --git a/src/nextline_rdb/schema/pagination/db.py b/src/nextline_rdb/schema/pagination/db.py index 72e6ca8..c65403d 100644 --- a/src/nextline_rdb/schema/pagination/db.py +++ b/src/nextline_rdb/schema/pagination/db.py @@ -25,7 +25,7 @@ def decode_id(cursor: str) -> int: _T = TypeVar("_T") -def load_connection( +async def load_connection( info: Info, Model: Type[db_models.Model], id_field: str, @@ -43,7 +43,7 @@ def load_connection( create_node_from_model=create_node_from_model, ) - return query_connection( + return await query_connection( info, query_edges, before, @@ -53,7 +53,7 @@ def load_connection( ) -def load_edges( +async def load_edges( info: Info, Model: Type[db_models.Model], id_field: str, @@ -66,7 +66,7 @@ def load_edges( ) -> list[Edge[_T]]: session = info.context["session"] - models = load_models( + models = await load_models( session, Model, id_field, diff --git a/src/nextline_rdb/schema/query.py b/src/nextline_rdb/schema/query.py index f618a0e..291ffa3 100644 --- a/src/nextline_rdb/schema/query.py +++ b/src/nextline_rdb/schema/query.py @@ -24,8 +24,8 @@ class History: @strawberry.type class Query: @strawberry.field - def history(self, info: Info) -> History: + async def history(self, info: Info) -> History: db = info.context["db"] - with db.session() as session: + async with db.session() as session: info.context["session"] = session return History() diff --git a/src/nextline_rdb/schema/types.py b/src/nextline_rdb/schema/types.py index 2767074..5ff0977 100644 --- a/src/nextline_rdb/schema/types.py +++ b/src/nextline_rdb/schema/types.py @@ -12,7 +12,7 @@ from .pagination import Connection, load_connection -def query_connection_run( +async def query_connection_run( info: Info, before: Optional[str] = None, after: Optional[str] = None, @@ -21,10 +21,10 @@ def query_connection_run( ) -> Connection[RunHistory]: Model = db_models.Run NodeType = RunHistory - return query_connection(info, before, after, first, last, Model, NodeType) + return await query_connection(info, before, after, first, last, Model, NodeType) -def query_connection_trace( +async def query_connection_trace( info: Info, before: Optional[str] = None, after: Optional[str] = None, @@ -33,10 +33,10 @@ def query_connection_trace( ) -> Connection[TraceHistory]: Model = db_models.Trace NodeType = TraceHistory - return query_connection(info, before, after, first, last, Model, NodeType) + return await query_connection(info, before, after, first, last, Model, NodeType) -def query_connection_prompt( +async def query_connection_prompt( info: Info, before: Optional[str] = None, after: Optional[str] = None, @@ -45,10 +45,10 @@ def query_connection_prompt( ) -> Connection[PromptHistory]: Model = db_models.Prompt NodeType = PromptHistory - return query_connection(info, before, after, first, last, Model, NodeType) + return await query_connection(info, before, after, first, last, Model, NodeType) -def query_connection_stdout( +async def query_connection_stdout( info: Info, before: Optional[str] = None, after: Optional[str] = None, @@ -57,13 +57,13 @@ def query_connection_stdout( ) -> Connection[StdoutHistory]: Model = db_models.Stdout NodeType = StdoutHistory - return query_connection(info, before, after, first, last, Model, NodeType) + return await query_connection(info, before, after, first, last, Model, NodeType) _T = TypeVar("_T") -def query_connection( +async def query_connection( info: Info, before: Optional[str], after: Optional[str], @@ -77,7 +77,7 @@ def query_connection( create_node_from_model = NodeType.from_model # type: ignore - return load_connection( + return await load_connection( info, Model, id_field, diff --git a/tests/pagination/conftest.py b/tests/pagination/conftest.py index aeea4bf..6978193 100644 --- a/tests/pagination/conftest.py +++ b/tests/pagination/conftest.py @@ -1,29 +1,26 @@ +from collections.abc import AsyncIterator + import pytest -from sqlalchemy import create_engine, select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) from .models import Base, Entity @pytest.fixture -def session(db, sample): +async def session(db: async_sessionmaker, sample) -> AsyncIterator[AsyncSession]: del sample - with db() as y: + async with db() as y: yield y -def test_sample(db, sample): - del sample - Model = Entity - with db() as session: - stmt = select(Model) - models = session.scalars(stmt) - assert 10 == len(models.all()) - - @pytest.fixture -def sample(db): - with db.begin() as session: +async def sample(db: async_sessionmaker): + async with db.begin() as session: num = [3, 3, 3, 2, 2, 2, 1, 1, 1, 1] txt = ["AA", "BB", "AA", "AA", "BB", "AA", "AA", "BB", "AA", "BB"] for i in range(10): @@ -32,14 +29,18 @@ def sample(db): @pytest.fixture -def db(engine): - Base.metadata.create_all(bind=engine) - y = sessionmaker(autocommit=False, autoflush=False, bind=engine) +async def db(engine) -> async_sessionmaker: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + y = async_sessionmaker(bind=engine, expire_on_commit=False) return y @pytest.fixture -def engine(): - url = "sqlite:///:memory:?check_same_thread=false" - y = create_engine(url) - return y +async def engine() -> AsyncIterator[AsyncEngine]: + url = 'sqlite+aiosqlite://' + y = create_async_engine(url) + try: + yield y + finally: + await y.dispose() diff --git a/tests/pagination/test_falsy_id.py b/tests/pagination/test_falsy_id.py index 6b7ff12..70d14c0 100644 --- a/tests/pagination/test_falsy_id.py +++ b/tests/pagination/test_falsy_id.py @@ -1,4 +1,5 @@ import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from nextline_rdb.pagination import SortField, load_models @@ -21,10 +22,10 @@ @pytest.mark.parametrize("kwargs, expected", params) -def test_one(session, kwargs, expected): +async def test_one(session: AsyncSession, kwargs, expected): Model = Entity id_field = "id" - models = load_models(session, Model, id_field, **kwargs) + models = await load_models(session, Model, id_field, **kwargs) assert expected == [m.id for m in models] @@ -45,16 +46,16 @@ def test_one(session, kwargs, expected): @pytest.mark.parametrize("kwargs, expected", params) -def test_str(session, kwargs, expected): +async def test_str(session: AsyncSession, kwargs, expected): Model = Entity id_field = "txt" - models = load_models(session, Model, id_field, **kwargs) + models = await load_models(session, Model, id_field, **kwargs) assert expected == [getattr(m, id_field) for m in models] @pytest.fixture -def sample(db): - with db.begin() as session: +async def sample(db: async_sessionmaker): + async with db.begin() as session: num = [3, 3, 3, 2, 2, 2, 1, 1, 1, 1] txt = ["", "A", "B", "C", "D", "E", "F", "G", "H", "I"] for i in range(10): diff --git a/tests/pagination/test_fixture.py b/tests/pagination/test_fixture.py index 1c08958..c69d7bf 100644 --- a/tests/pagination/test_fixture.py +++ b/tests/pagination/test_fixture.py @@ -1,10 +1,11 @@ from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from .models import Entity -def test_sample(session): +async def test_sample(session: AsyncSession): Model = Entity stmt = select(Model) - models = session.scalars(stmt) + models = await session.scalars(stmt) assert 10 == len(models.all()) diff --git a/tests/pagination/test_pagination.py b/tests/pagination/test_pagination.py index b36f3cc..c6acf67 100644 --- a/tests/pagination/test_pagination.py +++ b/tests/pagination/test_pagination.py @@ -1,13 +1,14 @@ import pytest +from sqlalchemy.ext.asyncio import AsyncSession from nextline_rdb.pagination import load_models from .models import Entity -def test_all(session): +async def test_all(session: AsyncSession): Model = Entity - models = load_models(session, Model, "id") + models = await load_models(session, Model, "id") expected = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] assert expected == [m.id for m in models] @@ -23,9 +24,9 @@ def test_all(session): @pytest.mark.parametrize("kwargs, expected", params) -def test_forward(session, kwargs, expected): +async def test_forward(session: AsyncSession, kwargs, expected): Model = Entity - models = load_models(session, Model, "id", **kwargs) + models = await load_models(session, Model, "id", **kwargs) assert expected == [m.id for m in models] @@ -43,9 +44,9 @@ def test_forward(session, kwargs, expected): @pytest.mark.parametrize("kwargs, expected", params) -def test_forward_with_after(session, kwargs, expected): +async def test_forward_with_after(session: AsyncSession, kwargs, expected): Model = Entity - models = load_models(session, Model, "id", **kwargs) + models = await load_models(session, Model, "id", **kwargs) assert expected == [m.id for m in models] @@ -60,9 +61,9 @@ def test_forward_with_after(session, kwargs, expected): @pytest.mark.parametrize("kwargs, expected", params) -def test_backward(session, kwargs, expected): +async def test_backward(session: AsyncSession, kwargs, expected): Model = Entity - models = load_models(session, Model, "id", **kwargs) + models = await load_models(session, Model, "id", **kwargs) assert expected == [m.id for m in models] @@ -80,7 +81,7 @@ def test_backward(session, kwargs, expected): @pytest.mark.parametrize("kwargs, expected", params) -def test_backward_with_before(session, kwargs, expected): +async def test_backward_with_before(session: AsyncSession, kwargs, expected): Model = Entity - models = load_models(session, Model, "id", **kwargs) + models = await load_models(session, Model, "id", **kwargs) assert expected == [m.id for m in models] diff --git a/tests/pagination/test_sort.py b/tests/pagination/test_sort.py index 4d2fb62..2cb100b 100644 --- a/tests/pagination/test_sort.py +++ b/tests/pagination/test_sort.py @@ -1,4 +1,5 @@ import pytest +from sqlalchemy.ext.asyncio import AsyncSession from nextline_rdb.pagination import SortField, load_models @@ -65,8 +66,8 @@ @pytest.mark.parametrize("kwargs, expected", params) -def test_sort(session, kwargs, expected): +async def test_sort(session: AsyncSession, kwargs, expected): Model = Entity id_field = "id" - models = load_models(session, Model, id_field, **kwargs) + models = await load_models(session, Model, id_field, **kwargs) assert expected == [m.id for m in models] diff --git a/tests/schema/queries/test_history.py b/tests/schema/queries/test_history.py index 0a237d6..e11a82f 100644 --- a/tests/schema/queries/test_history.py +++ b/tests/schema/queries/test_history.py @@ -1,19 +1,7 @@ -from collections.abc import Callable - -import pytest -import strawberry from async_asgi_testclient import TestClient -from hypothesis import given, settings -from hypothesis import strategies as st from nextlinegraphql.plugins.ctrl.test import run_statement from nextlinegraphql.plugins.graphql.test import gql_request -from nextline_rdb.db import DB -from nextline_rdb.db.adb import AsyncDB -from nextline_rdb.models.strategies import st_model_run_list -from nextline_rdb.schema import Query -from nextline_rdb.utils import ensure_sync_url - from ..graphql import QUERY_HISTORY @@ -32,47 +20,3 @@ async def test_one(client: TestClient): assert run["endedAt"] assert run["script"] assert not run["exception"] - - -@given(st.data()) -@settings(max_examples=200) -async def test_st_model_run_lists( - tmp_url_factory: Callable[[], str], - data: st.DataObject, -) -> None: - max_size = 10 - runs = data.draw(st_model_run_list(generate_traces=True, max_size=max_size)) - - # ic(runs) - - schema = strawberry.Schema(query=Query) - - url = tmp_url_factory() - - # db_ = DB(ensure_sync_url(url)) - with DB(ensure_sync_url(url)) as db_: - with db_.session() as session: - pass - - async with AsyncDB(url, use_migration=False) as db: - async with db.session.begin() as session: - session.add_all(runs) - - # db_ = DB(ensure_sync_url(url)) - with DB(ensure_sync_url(url)) as db_: - # with db_.session() as session: - # pass - # resp = schema.execute_sync(QUERY_HISTORY, context_value={'db': db_}) - resp = await schema.execute(QUERY_HISTORY, context_value={'db': db_}) - # ic(resp) - assert not resp.errors - - -@pytest.fixture(scope='session') -def tmp_url_factory(tmp_path_factory: pytest.TempPathFactory) -> Callable[[], str]: - def factory() -> str: - dir = tmp_path_factory.mktemp('db') - url = f'sqlite+aiosqlite:///{dir}/db.sqlite' - return url - - return factory diff --git a/tests/schema/queries/test_pagenation.py b/tests/schema/queries/test_pagenation.py index cea34d8..01213a1 100644 --- a/tests/schema/queries/test_pagenation.py +++ b/tests/schema/queries/test_pagenation.py @@ -6,7 +6,7 @@ from async_asgi_testclient import TestClient from nextlinegraphql.plugins.graphql.test import gql_request, gql_request_response -from nextline_rdb import DB +from nextline_rdb import AsyncDB from nextline_rdb.models import Run from ..graphql import QUERY_HISTORY_RUNS @@ -323,7 +323,7 @@ async def assert_results(client: TestClient, variables, expected): edges = all_runs["edges"] # print(page_info) - print(edges) + # print(edges) assert expected_page_info == page_info @@ -381,25 +381,9 @@ async def test_error_forward_and_backward(sample, client, variables): @pytest.fixture -def sample(db: DB): - with db.session() as session: - with session.begin(): - for run_no in range(11, 111): - model = Run( - run_no=run_no, - state="running", - started_at=datetime.datetime.utcnow(), - ended_at=datetime.datetime.utcnow(), - script="pass", - ) - session.add(model) - - -@pytest.fixture -def sample_one(db: DB): - with db.session() as session: - with session.begin(): - run_no = 10 +async def sample(db: AsyncDB): + async with db.session.begin() as session: + for run_no in range(11, 111): model = Run( run_no=run_no, state="running", @@ -411,12 +395,26 @@ def sample_one(db: DB): @pytest.fixture -def sample_empty(db: DB): +async def sample_one(db: AsyncDB): + async with db.session.begin() as session: + run_no = 10 + model = Run( + run_no=run_no, + state="running", + started_at=datetime.datetime.utcnow(), + ended_at=datetime.datetime.utcnow(), + script="pass", + ) + session.add(model) + + +@pytest.fixture +def sample_empty(db: AsyncDB): del db @pytest.fixture -def app(db: DB): +def app(db: AsyncDB): # NOTE: Overriding the app fixture from conftest.py because it adds an # entry in the DB. The factory.create_app() needs to be refactored so this # override is not needed. @@ -445,7 +443,7 @@ async def get_context(self, request, response=None) -> Optional[Any]: @pytest.fixture -def db(): +async def db(): url = 'sqlite:///:memory:?check_same_thread=false' - with DB(url=url) as db: + async with AsyncDB(url=url) as db: yield db From c71085e60d157ee5ecb17ef096ec388efee771bb Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Thu, 14 Dec 2023 11:30:11 -0500 Subject: [PATCH 2/4] Update the default DB URL --- src/nextline_rdb/default.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nextline_rdb/default.toml b/src/nextline_rdb/default.toml index d6721b1..c6fed3f 100644 --- a/src/nextline_rdb/default.toml +++ b/src/nextline_rdb/default.toml @@ -4,7 +4,7 @@ # https://github.com/rochacbruno/learndynaconf/tree/main/configs [db] -url = "sqlite:///:memory:?check_same_thread=false" +url = "sqlite+aiosqlite://" [logging.loggers.nextline_rdb] handlers = ["default"] From ca7117c40bbfbfca0a68bf37017343bc36e2cc88 Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Thu, 14 Dec 2023 12:11:28 -0500 Subject: [PATCH 3/4] Switch to async in alembic migrations --- src/nextline_rdb/alembic.ini | 2 +- src/nextline_rdb/alembic/env.py | 38 ++++++++++++++++++++++----------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/nextline_rdb/alembic.ini b/src/nextline_rdb/alembic.ini index 4779647..428f944 100644 --- a/src/nextline_rdb/alembic.ini +++ b/src/nextline_rdb/alembic.ini @@ -56,7 +56,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne # output_encoding = utf-8 # sqlalchemy.url = driver://user:pass@localhost/dbname -sqlalchemy.url = sqlite:///migration.sqlite3 +sqlalchemy.url = sqlite+aiosqlite:///migration.sqlite3 [post_write_hooks] diff --git a/src/nextline_rdb/alembic/env.py b/src/nextline_rdb/alembic/env.py index 6376b6c..787b875 100644 --- a/src/nextline_rdb/alembic/env.py +++ b/src/nextline_rdb/alembic/env.py @@ -1,8 +1,10 @@ +import asyncio import logging import logging.config from alembic import context -from sqlalchemy import create_engine +from sqlalchemy import Connection +from sqlalchemy.ext.asyncio import create_async_engine from nextline_rdb import models @@ -52,6 +54,27 @@ def run_migrations_offline() -> None: context.run_migrations() +def do_run_migrations(connection: Connection) -> None: + context.configure( + connection=connection, + target_metadata=target_metadata, + render_as_batch=True, + ) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + assert url is not None + connectable = create_async_engine(url) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + def run_migrations_online() -> None: """Run migrations in 'online' mode. @@ -59,18 +82,7 @@ def run_migrations_online() -> None: and associate a connection with the context. """ - assert url is not None - connectable = create_engine(url) - - with connectable.connect() as connection: - context.configure( - connection=connection, - target_metadata=target_metadata, - render_as_batch=True, - ) - - with context.begin_transaction(): - context.run_migrations() + asyncio.run(run_async_migrations()) if context.is_offline_mode(): From 64f0067e676e607b57a56d36c2ddbbd05db74d7f Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Thu, 14 Dec 2023 12:25:07 -0500 Subject: [PATCH 4/4] Update test fixtures --- tests/alembic/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/alembic/conftest.py b/tests/alembic/conftest.py index 313c6ac..c127a0b 100644 --- a/tests/alembic/conftest.py +++ b/tests/alembic/conftest.py @@ -16,7 +16,7 @@ def alembic_config() -> Config: @pytest.fixture def alembic_config_in_memory(alembic_config: Config) -> Config: config = alembic_config - url = 'sqlite://' + url = 'sqlite+aiosqlite://' config.set_main_option('sqlalchemy.url', url) return config @@ -27,6 +27,6 @@ def alembic_config_temp_sqlite( ) -> Config: config = alembic_config dir = tmp_path_factory.mktemp('db') - url = f'sqlite:///{dir}/db.sqlite' + url = f'sqlite+aiosqlite:///{dir}/db.sqlite' config.set_main_option('sqlalchemy.url', url) return config