Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor repository generation within core container to re-use a call… #44

Merged
merged 1 commit into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def setup_factory_session(test_config_obj):
def mock_container() -> Generator[Container, None, None]:
mocked_container = copy(base_mock_container)
mocked_container.db_client = MagicMock()
mocked_container.repository_factory = MagicMock()
mocked_container.init_resources()
yield mocked_container
mocked_container.unwire()
Expand Down
20 changes: 10 additions & 10 deletions src/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from logging.config import dictConfig

from dependency_injector.containers import DeclarativeContainer
from dependency_injector.providers import Callable
from dependency_injector.providers import Configuration
from dependency_injector.providers import Factory
from dependency_injector.providers import Resource
Expand Down Expand Up @@ -73,19 +74,18 @@ class Container(DeclarativeContainer):
logging = Resource(dictConfig, config=config.logger)

db_client = Singleton(Database, db_url=config.db.async_database_uri)
repository_factory = Callable(BaseRepository, session_factory=db_client.provided.get_session)

user_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=User)
user_service = Factory(UserService, repository=user_repository)
user_service = Factory(UserService, repository_factory=repository_factory, model=User)

quest_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=Quest)
user_quest_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=UserQuest)
quest_service = Factory(
QuestService, quest_repository=quest_repository, user_quest_repository=user_quest_repository
QuestService, repository_factory=repository_factory, quest_model=Quest, user_quest_model=UserQuest
)

xp_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=ExperienceTransaction)
xp_service = Factory(ExperienceTransactionService, repository=xp_repository)
xp_service = Factory(
ExperienceTransactionService, repository_factory=repository_factory, model=ExperienceTransaction
)

menu_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=Menu)
menu_item_repository = Factory(BaseRepository, session_factory=db_client.provided.get_session, model=MenuItem)
tavern_service = Factory(TavernService, menu_repository=menu_repository, menu_item_repository=menu_item_repository)
tavern_service = Factory(
TavernService, repository_factory=repository_factory, menu_model=Menu, menu_item_model=MenuItem
)
18 changes: 14 additions & 4 deletions src/services.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from collections.abc import Callable
from logging import Logger
from logging import getLogger

from src.helpers.sqlalchemy_helpers import QueryArgs
from src.models import User
from src.repositories import BaseRepository
from src.typeshed import BaseModelType
from src.typeshed import RepositoryHandler


Expand All @@ -17,17 +19,25 @@ def __init__(self) -> None:
class SingleRepoService(BaseService):
_repository: BaseRepository

def __init__(self, repository: BaseRepository) -> None:
def __init__(
self,
repository_factory: Callable[[type[BaseModelType]], BaseRepository[BaseModelType]],
model: type[BaseModelType],
) -> None:
super().__init__()
self._repository = repository
self._repository = repository_factory(model)


class MultiRepoService(BaseService):
_repositories: RepositoryHandler

def __init__(self, **repositories: BaseRepository) -> None:
def __init__(
self,
repository_factory: Callable[[type[BaseModelType]], BaseRepository[BaseModelType]],
**models: type[BaseModelType],
) -> None:
super().__init__()
self._repositories = RepositoryHandler(**repositories)
self._repositories = RepositoryHandler(repository_factory, **models)


class UserService(SingleRepoService):
Expand Down
31 changes: 24 additions & 7 deletions src/tests/bot/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ async def test_accept_quest_if_available(self, user):
get_first=AsyncMock(return_value=MagicMock(users=[])), session=AsyncMock(commit=AsyncMock())
)
mock_user_quest_repo = AsyncMock(get_count=AsyncMock(return_value=0), add=AsyncMock())
quest_service = QuestService(quest_repository=mock_quest_repository, user_quest_repository=mock_user_quest_repo)
mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, mock_user_quest_repo])
quest_service = QuestService(
repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock()
)
# Act
res = await quest_service.accept_quest_if_available(user, "Quest title")
# Assert
Expand All @@ -30,7 +33,10 @@ async def test_accept_quest_if_available(self, user):
async def test_quest_dne(self, user):
# Arrange
mock_quest_repository = AsyncMock(get_first=AsyncMock(return_value=None))
quest_service = QuestService(quest_repository=mock_quest_repository, user_quest_repository=AsyncMock())
mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, AsyncMock()])
quest_service = QuestService(
repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock()
)
# Act & Assert
with pytest.raises(QuestDNE):
await quest_service.accept_quest_if_available(user, "Quest title")
Expand All @@ -42,7 +48,10 @@ async def test_quest_already_accepted(self, user):
get_first=AsyncMock(return_value=quest), session=AsyncMock(commit=AsyncMock())
)
mock_user_quest_repo = AsyncMock(get_count=AsyncMock(return_value=1))
quest_service = QuestService(quest_repository=mock_quest_repository, user_quest_repository=mock_user_quest_repo)
mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, mock_user_quest_repo])
quest_service = QuestService(
repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock()
)
# Act & Assert
with pytest.raises(QuestAlreadyAccepted):
await quest_service.accept_quest_if_available(user, "Quest title")
Expand All @@ -58,8 +67,9 @@ async def test_quest_completed(self, user, max_completion_count):
mock_user_quest_repository = AsyncMock(
get_count=AsyncMock(return_value=1), get_first=AsyncMock(return_value=user_quest)
)
mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, mock_user_quest_repository])
quest_service = QuestService(
quest_repository=mock_quest_repository, user_quest_repository=mock_user_quest_repository
repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock()
)
# Act
res = await quest_service.complete_quest_if_available(user, "Quest title")
Expand All @@ -69,7 +79,10 @@ async def test_quest_completed(self, user, max_completion_count):
async def test_cannot_complete_nonexistent_quest(self, user):
# Arrange
mock_quest_repository = AsyncMock(get_first=AsyncMock(return_value=None))
quest_service = QuestService(quest_repository=mock_quest_repository, user_quest_repository=AsyncMock())
mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, AsyncMock()])
quest_service = QuestService(
repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock()
)
# Act & Assert
with pytest.raises(QuestDNE):
await quest_service.complete_quest_if_available(user, "Quest Title")
Expand All @@ -80,7 +93,10 @@ async def test_cannot_complete_unaccepted_quest(self, user):
get_first=AsyncMock(return_value=MagicMock()), session=AsyncMock(commit=AsyncMock())
)
mock_user_quest_repo = AsyncMock(get_first=AsyncMock(return_value=None))
quest_service = QuestService(quest_repository=mock_quest_repository, user_quest_repository=mock_user_quest_repo)
mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, mock_user_quest_repo])
quest_service = QuestService(
repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock()
)
# Act & Assert
with pytest.raises(QuestNotAccepted):
await quest_service.complete_quest_if_available(user, "Quest Title")
Expand All @@ -92,8 +108,9 @@ async def test_max_completion_count_reached(self, user):
get_first=AsyncMock(return_value=quest), session=AsyncMock(commit=AsyncMock())
)
mock_user_quest_repository = AsyncMock(get_count=AsyncMock(return_value=1), get_first=AsyncMock())
mock_repository_factory = MagicMock(side_effect=[mock_quest_repository, mock_user_quest_repository])
quest_service = QuestService(
quest_repository=mock_quest_repository, user_quest_repository=mock_user_quest_repository
repository_factory=mock_repository_factory, quest_model=MagicMock(), user_quest_model=MagicMock()
)
# Act & Assert
with pytest.raises(MaxQuestCompletionReached):
Expand Down
7 changes: 4 additions & 3 deletions src/tests/tavern/test_services.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

import pytest

Expand All @@ -22,8 +23,7 @@ class TestDeleteMenuItem:
)
async def test_no_items_found(self, faker, menu_items, item_name, day_of_week):
# Arrange
menu_item_repo = AsyncMock()
tavern_service = TavernService(menu_item_repository=menu_item_repo)
tavern_service = TavernService(repository_factory=AsyncMock(), menu_item_model=MagicMock())
menu = Menu(
server_id=faker.random_number(digits=10, fix_len=True), start_date=faker.date_object(), items=menu_items
)
Expand All @@ -36,7 +36,8 @@ async def test_no_items_found(self, faker, menu_items, item_name, day_of_week):
async def test_item_deleted(self, faker, day_of_week):
# Arrange
menu_item_repository = AsyncMock(delete=AsyncMock())
tavern_service = TavernService(menu_item_repository=menu_item_repository)
mock_repository_factory = MagicMock(return_value=menu_item_repository)
tavern_service = TavernService(repository_factory=mock_repository_factory, menu_item_model=MagicMock())
menu_item = MenuItem(food="Food", day_of_the_week=DayOfWeek.MONDAY)
menu = Menu(
server_id=faker.random_number(digits=10, fix_len=True), start_date=faker.date_object(), items=[menu_item]
Expand Down
11 changes: 8 additions & 3 deletions src/typeshed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Callable
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import field
Expand Down Expand Up @@ -66,6 +67,10 @@ class MixinData(BaseModel):


class RepositoryHandler:
def __init__(self, **repositories: "BaseRepository") -> None:
for key, repository in repositories.items():
setattr(self, key.removesuffix("_repository"), repository)
def __init__(
self,
repository_factory: Callable[[type[BaseModelType]], "BaseRepository[BaseModelType]"],
**models: type[BaseModelType]
) -> None:
for key, model in models.items():
setattr(self, key.removesuffix("_model"), repository_factory(model))