diff --git a/src/conftest.py b/src/conftest.py index d7d0a77..e538154 100644 --- a/src/conftest.py +++ b/src/conftest.py @@ -42,6 +42,7 @@ def setup_factory_session(test_config_obj): def mock_container() -> Generator[Container, None, None]: mocked_container = copy(base_mock_container) mocked_container.db_client = MagicMock() + mocked_container.repository_factory = MagicMock() mocked_container.init_resources() yield mocked_container mocked_container.unwire() diff --git a/src/containers.py b/src/containers.py index 3d7b1ce..04fb924 100644 --- a/src/containers.py +++ b/src/containers.py @@ -5,6 +5,7 @@ from logging.config import dictConfig from dependency_injector.containers import DeclarativeContainer +from dependency_injector.providers import Callable from dependency_injector.providers import Configuration from dependency_injector.providers import Factory from dependency_injector.providers import Resource @@ -73,19 +74,18 @@ class Container(DeclarativeContainer): logging = Resource(dictConfig, config=config.logger) db_client = Singleton(Database, db_url=config.db.async_database_uri) + repository_factory = Callable(BaseRepository, session_factory=db_client.provided.get_session) - user_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=User) - user_service = Factory(UserService, repository=user_repository) + user_service = Factory(UserService, repository_factory=repository_factory, model=User) - quest_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=Quest) - user_quest_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=UserQuest) quest_service = Factory( - QuestService, quest_repository=quest_repository, user_quest_repository=user_quest_repository + QuestService, repository_factory=repository_factory, quest_model=Quest, user_quest_model=UserQuest ) - xp_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=ExperienceTransaction) - xp_service = Factory(ExperienceTransactionService, repository=xp_repository) + xp_service = Factory( + ExperienceTransactionService, repository_factory=repository_factory, model=ExperienceTransaction + ) - menu_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=Menu) - menu_item_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=MenuItem) - tavern_service = Factory(TavernService, menu_repository=menu_repository, menu_item_repository=menu_item_repository) + tavern_service = Factory( + TavernService, repository_factory=repository_factory, menu_model=Menu, menu_item_model=MenuItem + ) diff --git a/src/services.py b/src/services.py index 53231c3..cdd59c9 100644 --- a/src/services.py +++ b/src/services.py @@ -1,9 +1,11 @@ +from collections.abc import Callable from logging import Logger from logging import getLogger from src.helpers.sqlalchemy_helpers import QueryArgs from src.models import User from src.repositories import BaseRepository +from src.typeshed import BaseModelType from src.typeshed import RepositoryHandler @@ -17,17 +19,25 @@ def __init__(self) -> None: class SingleRepoService(BaseService): _repository: BaseRepository - def __init__(self, repository: BaseRepository) -> None: + def __init__( + self, + repository_factory: Callable[[type[BaseModelType]], BaseRepository[BaseModelType]], + model: type[BaseModelType], + ) -> None: super().__init__() - self._repository = repository + self._repository = repository_factory(model) class MultiRepoService(BaseService): _repositories: RepositoryHandler - def __init__(self, **repositories: BaseRepository) -> None: + def __init__( + self, + repository_factory: Callable[[type[BaseModelType]], BaseRepository[BaseModelType]], + **models: type[BaseModelType], + ) -> None: super().__init__() - self._repositories = RepositoryHandler(**repositories) + self._repositories = RepositoryHandler(repository_factory, **models) class UserService(SingleRepoService): diff --git a/src/tests/bot/test_services.py b/src/tests/bot/test_services.py index 55d9fb6..08b0022 100644 --- a/src/tests/bot/test_services.py +++ b/src/tests/bot/test_services.py @@ -20,7 +20,10 @@ async def test_accept_quest_if_available(self, user): get_first=AsyncMock(return_value=MagicMock(users=[])), session=AsyncMock(commit=AsyncMock()) ) mock_user_quest_repo = AsyncMock(get_count=AsyncMock(return_value=0), add=AsyncMock()) - quest_service = QuestService(quest_repository=mock_quest_repository, user_quest_repository=mock_user_quest_repo) + mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, mock_user_quest_repo]) + quest_service = QuestService( + repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock() + ) # Act res = await quest_service.accept_quest_if_available(user, "Quest title") # Assert @@ -30,7 +33,10 @@ async def test_accept_quest_if_available(self, user): async def test_quest_dne(self, user): # Arrange mock_quest_repository = AsyncMock(get_first=AsyncMock(return_value=None)) - quest_service = QuestService(quest_repository=mock_quest_repository, user_quest_repository=AsyncMock()) + mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, AsyncMock()]) + quest_service = QuestService( + repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock() + ) # Act & Assert with pytest.raises(QuestDNE): await quest_service.accept_quest_if_available(user, "Quest title") @@ -42,7 +48,10 @@ async def test_quest_already_accepted(self, user): get_first=AsyncMock(return_value=quest), session=AsyncMock(commit=AsyncMock()) ) mock_user_quest_repo = AsyncMock(get_count=AsyncMock(return_value=1)) - quest_service = QuestService(quest_repository=mock_quest_repository, user_quest_repository=mock_user_quest_repo) + mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, mock_user_quest_repo]) + quest_service = QuestService( + repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock() + ) # Act & Assert with pytest.raises(QuestAlreadyAccepted): await quest_service.accept_quest_if_available(user, "Quest title") @@ -58,8 +67,9 @@ async def test_quest_completed(self, user, max_completion_count): mock_user_quest_repository = AsyncMock( get_count=AsyncMock(return_value=1), get_first=AsyncMock(return_value=user_quest) ) + mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, mock_user_quest_repository]) quest_service = QuestService( - quest_repository=mock_quest_repository, user_quest_repository=mock_user_quest_repository + repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock() ) # Act res = await quest_service.complete_quest_if_available(user, "Quest title") @@ -69,7 +79,10 @@ async def test_quest_completed(self, user, max_completion_count): async def test_cannot_complete_nonexistent_quest(self, user): # Arrange mock_quest_repository = AsyncMock(get_first=AsyncMock(return_value=None)) - quest_service = QuestService(quest_repository=mock_quest_repository, user_quest_repository=AsyncMock()) + mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, AsyncMock()]) + quest_service = QuestService( + repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock() + ) # Act & Assert with pytest.raises(QuestDNE): await quest_service.complete_quest_if_available(user, "Quest Title") @@ -80,7 +93,10 @@ async def test_cannot_complete_unaccepted_quest(self, user): get_first=AsyncMock(return_value=MagicMock()), session=AsyncMock(commit=AsyncMock()) ) mock_user_quest_repo = AsyncMock(get_first=AsyncMock(return_value=None)) - quest_service = QuestService(quest_repository=mock_quest_repository, user_quest_repository=mock_user_quest_repo) + mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, mock_user_quest_repo]) + quest_service = QuestService( + repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock() + ) # Act & Assert with pytest.raises(QuestNotAccepted): await quest_service.complete_quest_if_available(user, "Quest Title") @@ -92,8 +108,9 @@ async def test_max_completion_count_reached(self, user): get_first=AsyncMock(return_value=quest), session=AsyncMock(commit=AsyncMock()) ) mock_user_quest_repository = AsyncMock(get_count=AsyncMock(return_value=1), get_first=AsyncMock()) + mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, mock_user_quest_repository]) quest_service = QuestService( - quest_repository=mock_quest_repository, user_quest_repository=mock_user_quest_repository + repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock() ) # Act & Assert with pytest.raises(MaxQuestCompletionReached): diff --git a/src/tests/tavern/test_services.py b/src/tests/tavern/test_services.py index b9107b8..0ae5d40 100644 --- a/src/tests/tavern/test_services.py +++ b/src/tests/tavern/test_services.py @@ -1,4 +1,5 @@ from unittest.mock import AsyncMock +from unittest.mock import MagicMock import pytest @@ -22,8 +23,7 @@ class TestDeleteMenuItem: ) async def test_no_items_found(self, faker, menu_items, item_name, day_of_week): # Arrange - menu_item_repo = AsyncMock() - tavern_service = TavernService(menu_item_repository=menu_item_repo) + tavern_service = TavernService(repository_factory=AsyncMock(), menu_item_model=MagicMock()) menu = Menu( server_id=faker.random_number(digits=10, fix_len=True), start_date=faker.date_object(), items=menu_items ) @@ -36,7 +36,8 @@ async def test_no_items_found(self, faker, menu_items, item_name, day_of_week): async def test_item_deleted(self, faker, day_of_week): # Arrange menu_item_repository = AsyncMock(delete=AsyncMock()) - tavern_service = TavernService(menu_item_repository=menu_item_repository) + mock_repository_factory = MagicMock(return_value=menu_item_repository) + tavern_service = TavernService(repository_factory=mock_repository_factory, menu_item_model=MagicMock()) menu_item = MenuItem(food="Food", day_of_the_week=DayOfWeek.MONDAY) menu = Menu( server_id=faker.random_number(digits=10, fix_len=True), start_date=faker.date_object(), items=[menu_item] diff --git a/src/typeshed.py b/src/typeshed.py index 993d975..4c4b76d 100644 --- a/src/typeshed.py +++ b/src/typeshed.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from collections.abc import Sequence from dataclasses import dataclass from dataclasses import field @@ -66,6 +67,10 @@ class MixinData(BaseModel): class RepositoryHandler: - def __init__(self, **repositories: "BaseRepository") -> None: - for key, repository in repositories.items(): - setattr(self, key.removesuffix("_repository"), repository) + def __init__( + self, + repository_factory: Callable[[type[BaseModelType]], "BaseRepository[BaseModelType]"], + **models: type[BaseModelType] + ) -> None: + for key, model in models.items(): + setattr(self, key.removesuffix("_model"), repository_factory(model))