From 6b0d6f3ee3ded95d9b7dc086724ec02eeca96e2e Mon Sep 17 00:00:00 2001 From: Nikita Semenov Date: Thu, 26 Dec 2024 12:18:40 +0300 Subject: [PATCH 1/2] fix(sqla_factory): added an async context manager in SQLAASyncPersistence --- polyfactory/factories/sqlalchemy_factory.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index 4b15ac52..f874d4d3 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -52,16 +52,18 @@ def __init__(self, session: AsyncSession) -> None: self.session = session async def save(self, data: T) -> T: - self.session.add(data) - await self.session.commit() - await self.session.refresh(data) + async with self.session as session: + session.add(data) + await session.commit() + await session.refresh(data) return data async def save_many(self, data: list[T]) -> list[T]: - self.session.add_all(data) - await self.session.commit() - for batch_item in data: - await self.session.refresh(batch_item) + async with self.session as session: + session.add_all(data) + await session.commit() + for batch_item in data: + await session.refresh(batch_item) return data From ff81dc1bdffdca489b5d8015fe756096d8d39353 Mon Sep 17 00:00:00 2001 From: Nikita Semenov Date: Fri, 27 Dec 2024 15:07:20 +0300 Subject: [PATCH 2/2] test(sqla_factory): refactored some tests after adding an async context manager --- .../test_sqlalchemy_factory_common.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 9d47366e..5c8084ce 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -18,6 +18,7 @@ func, inspect, orm, + select, text, types, ) @@ -343,13 +344,15 @@ class Factory(SQLAlchemyFactory[AsyncModel]): __async_session__ = session_config(session) __model__ = AsyncModel - result = await Factory.create_async() - assert inspect(result).persistent # type: ignore[union-attr] + instance = await Factory.create_async() + result = await session.scalar(select(AsyncModel).where(AsyncModel.id == instance.id)) + assert result batch_result = await Factory.create_batch_async(size=2) assert len(batch_result) == 2 for batch_item in batch_result: - assert inspect(batch_item).persistent # type: ignore[union-attr] + result = await session.scalar(select(AsyncModel).where(AsyncModel.id == batch_item.id)) + assert result @pytest.mark.parametrize( @@ -392,8 +395,9 @@ class Factory(SQLAlchemyFactory[AsyncRefreshModel]): test_int = Ignore() test_bool = Ignore() - result = await Factory.create_async() - assert inspect(result).persistent # type: ignore[union-attr] + instance = await Factory.create_async() + result = await session.scalar(select(AsyncRefreshModel).where(AsyncRefreshModel.id == instance.id)) + assert result assert result.test_datetime is not None assert isinstance(result.test_datetime, datetime) assert result.test_str == "test_str"