From 729d838355f52a86ac61a7a933dac88209c4c2ed Mon Sep 17 00:00:00 2001 From: Nikita Semenov Date: Fri, 27 Dec 2024 15:07:20 +0300 Subject: [PATCH] test(sqla_factory): refactored some tests after adding an async context manager --- .../test_sqlalchemy_factory_common.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 9d47366e..5c8084ce 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -18,6 +18,7 @@ func, inspect, orm, + select, text, types, ) @@ -343,13 +344,15 @@ class Factory(SQLAlchemyFactory[AsyncModel]): __async_session__ = session_config(session) __model__ = AsyncModel - result = await Factory.create_async() - assert inspect(result).persistent # type: ignore[union-attr] + instance = await Factory.create_async() + result = await session.scalar(select(AsyncModel).where(AsyncModel.id == instance.id)) + assert result batch_result = await Factory.create_batch_async(size=2) assert len(batch_result) == 2 for batch_item in batch_result: - assert inspect(batch_item).persistent # type: ignore[union-attr] + result = await session.scalar(select(AsyncModel).where(AsyncModel.id == batch_item.id)) + assert result @pytest.mark.parametrize( @@ -392,8 +395,9 @@ class Factory(SQLAlchemyFactory[AsyncRefreshModel]): test_int = Ignore() test_bool = Ignore() - result = await Factory.create_async() - assert inspect(result).persistent # type: ignore[union-attr] + instance = await Factory.create_async() + result = await session.scalar(select(AsyncRefreshModel).where(AsyncRefreshModel.id == instance.id)) + assert result assert result.test_datetime is not None assert isinstance(result.test_datetime, datetime) assert result.test_str == "test_str"