diff --git a/src/dioptra/restapi/v1/auth/service.py b/src/dioptra/restapi/v1/auth/service.py index 855df7adf..0815d6ad0 100644 --- a/src/dioptra/restapi/v1/auth/service.py +++ b/src/dioptra/restapi/v1/auth/service.py @@ -19,15 +19,17 @@ import datetime import uuid -from typing import Any, cast +from typing import Any import structlog from flask_login import current_user, login_user, logout_user from injector import inject from structlog.stdlib import BoundLogger -from dioptra.restapi.db import db, models -from dioptra.restapi.v1.users.service import UserNameService, UserPasswordService +from dioptra.restapi.db.repository.utils import DeletionPolicy +from dioptra.restapi.db.unit_of_work import UnitOfWork +from dioptra.restapi.v1.users.errors import UserDoesNotExistError +from dioptra.restapi.v1.users.service import UserPasswordService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -37,20 +39,17 @@ class AuthService(object): @inject def __init__( - self, - user_name_service: UserNameService, - user_password_service: UserPasswordService, + self, user_password_service: UserPasswordService, uow: UnitOfWork ) -> None: """Initialize the authentication service. All arguments are provided via dependency injection. Args: - user_name_service: A UserNameService object. user_password_service: A UserPasswordService object. """ - self._user_name_service = user_name_service self._user_password_service = user_password_service + self._uow = uow def login( self, @@ -68,12 +67,12 @@ def login( A dictionary containing the login success message. """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - user = cast( - models.User, - self._user_name_service.get( - username=username, error_if_not_found=True, log=log - ), - ) + + user = self._uow.user_repo.get_by_name(username, DeletionPolicy.NOT_DELETED) + if not user: + log.debug("User not found", username=username) + raise UserDoesNotExistError + self._user_password_service.authenticate( password=password, user_password=str(user.password), @@ -82,8 +81,8 @@ def login( log=log, ) login_user(user, remember=True) - user.last_login_on = datetime.datetime.now(tz=datetime.timezone.utc) - db.session.commit() + with self._uow: + user.last_login_on = datetime.datetime.now(tz=datetime.timezone.utc) log.debug("Login successful", user_id=user.user_id) return {"status": "Login successful", "username": username} @@ -102,8 +101,8 @@ def logout(self, everywhere: bool, **kwargs) -> dict[str, Any]: username = current_user.username if everywhere: - current_user.alternative_id = uuid.uuid4() - db.session.commit() + with self._uow: + current_user.alternative_id = uuid.uuid4() logout_user() log.debug("Logout successful", user_id=user_id, everywhere=everywhere) diff --git a/src/dioptra/restapi/v1/groups/service.py b/src/dioptra/restapi/v1/groups/service.py index 9216711ea..ac13c916c 100644 --- a/src/dioptra/restapi/v1/groups/service.py +++ b/src/dioptra/restapi/v1/groups/service.py @@ -56,7 +56,6 @@ class GroupService(object): @inject def __init__( self, - group_name_service: GroupNameService, group_member_service: GroupMemberService, group_manager_service: GroupManagerService, uow: UnitOfWork, @@ -66,11 +65,10 @@ def __init__( All arguments are provided via dependency injection. Args: - group_name_service: A GroupNameService object. group_member_service: A GroupMemberService object. group_manager_service: A GroupManagerService object. + uow: A UnitOfWork instance """ - self._group_name_service = group_name_service self._group_member_service = group_member_service self._group_manager_service = group_manager_service self._uow = uow @@ -170,8 +168,6 @@ class GroupIdService(object): @inject def __init__( self, - group_service: GroupService, - group_name_service: GroupNameService, uow: UnitOfWork, ) -> None: """Initialize the group ID service. @@ -179,11 +175,8 @@ def __init__( All arguments are provided via dependency injection. Args: - group_service: A GroupService object. - group_name_service: A GroupNameService object. + uow: A UnitOfWork instance """ - self._group_service = group_service - self._group_name_service = group_name_service self._uow = uow def get( @@ -253,7 +246,7 @@ def modify( return None - if self._group_name_service.get(name, log=log) is not None: + if self._uow.group_repo.get_by_name(name, DeletionPolicy.ANY) is not None: log.debug("Group name already exists", name=name) raise GroupNameNotAvailableError @@ -292,45 +285,6 @@ def delete(self, group_id: int, **kwargs) -> dict[str, Any]: return {"status": "Success", "group": group.name} -class GroupNameService(object): - """The service methods used to manage a group by name.""" - - @inject - def __init__(self, uow: UnitOfWork): - self._uow = uow - - def get( - self, name: str, error_if_not_found: bool = False, **kwargs - ) -> models.Group | None: - """Fetch a group by its name. - - Args: - name: The name of the group. - error_if_not_found: If True, raise an error if the group is not found. - Defaults to False. - - Returns: - The group object if found, otherwise None. - - Raises: - GroupDoesNotExistError: If the group is not found and `error_if_not_found` - is True. - """ - log: BoundLogger = kwargs.get("log", LOGGER.new()) - log.debug("Lookup group by name", name=name) - - group = self._uow.group_repo.get_by_name(name, DeletionPolicy.NOT_DELETED) - - if group is None: - if error_if_not_found: - log.debug("Group not found", name=name) - raise GroupDoesNotExistError - - return None - - return group - - class GroupMemberService(object): """The service methods used to manage a group's members.""" diff --git a/src/dioptra/restapi/v1/users/service.py b/src/dioptra/restapi/v1/users/service.py index 9e4112b8b..0454b5b28 100644 --- a/src/dioptra/restapi/v1/users/service.py +++ b/src/dioptra/restapi/v1/users/service.py @@ -31,7 +31,7 @@ from dioptra.restapi.db.repository.utils import DeletionPolicy from dioptra.restapi.db.unit_of_work import UnitOfWork from dioptra.restapi.errors import BackendDatabaseError -from dioptra.restapi.v1.groups.service import GroupMemberService, GroupNameService +from dioptra.restapi.v1.groups.service import GroupMemberService from dioptra.restapi.v1.plugin_parameter_types.service import ( BuiltinPluginParameterTypeService, ) @@ -71,8 +71,6 @@ class UserService(object): def __init__( self, user_password_service: UserPasswordService, - user_name_service: UserNameService, - group_name_service: GroupNameService, group_member_service: GroupMemberService, builtin_plugin_parameter_type_service: BuiltinPluginParameterTypeService, uow: UnitOfWork, @@ -83,16 +81,12 @@ def __init__( Args: user_password_service: A UserPasswordService object. - user_name_service: A UserNameService object. - group_name_service: A GroupNameService object. group_member_service: A GroupMemberService object. builtin_plugin_parameter_type_service: A BuiltinPluginParameterTypeService object. uow: A UnitOfWork instance """ self._user_password_service = user_password_service - self._user_name_service = user_name_service - self._group_name_service = group_name_service self._group_member_service = group_member_service self._builtin_plugin_parameter_type_service = ( builtin_plugin_parameter_type_service @@ -151,7 +145,7 @@ def create( ) if commit: - db.session.commit() + self._uow.commit() log.debug("User registration successful", user_id=new_user.user_id) return new_user @@ -213,37 +207,6 @@ def get( return users, total_num_users - def _get_user_by_email( - self, email_address: str, error_if_not_found: bool = False, **kwargs - ) -> models.User | None: - """Lookup a user by email address. - - Args: - email_address: The email address of the user. - error_if_not_found: If True, raise an error if the user is not found. - Defaults to False. - - Returns: - The user object if found, otherwise None. - - Raises: - UserDoesNotExistError: If the user is not found and `error_if_not_found` - is True. - """ - log: BoundLogger = kwargs.get("log", LOGGER.new()) - log.debug("Lookup user account by email", email_address=email_address) - - user = self._uow.user_repo.get_by_email(email_address, DeletionPolicy.ANY) - - if user is None: - if error_if_not_found: - log.debug("User not found", email_address=email_address) - raise UserDoesNotExistError - - return None - - return user - def _create_or_get_default_group( self, user: models.User, @@ -284,6 +247,7 @@ def __init__( Args: user_password_service: A UserPasswordService object. + uow: A UnitOfWork instance """ self._user_password_service = user_password_service self._uow = uow @@ -371,6 +335,7 @@ def __init__( Args: user_id_service: A UserIdService object. user_password_service: A UserPasswordService object. + uow: A UnitOfWork instance """ self._user_id_service = user_id_service self._user_password_service = user_password_service @@ -479,57 +444,6 @@ def change_password( ) -class UserNameService(object): - """The service methods used to register and manage user accounts by username.""" - - @inject - def __init__( - self, - user_password_service: UserPasswordService, - uow: UnitOfWork, - ) -> None: - """Initialize the user name service. - - All arguments are provided via dependency injection. - - Args: - user_password_service: A UserPasswordService object. - """ - self._user_password_service = user_password_service - self._uow = uow - - def get( - self, username: str, error_if_not_found: bool = False, **kwargs - ) -> models.User | None: - """Fetch a user by its username. - - Args: - username: The username of the user. - error_if_not_found: If True, raise an error if the user is not found. - Defaults to False. - - Returns: - The user object if found, otherwise None. - - Raises: - UserDoesNotExistError: If the user is not found and `error_if_not_found` - is True. - """ - log: BoundLogger = kwargs.get("log", LOGGER.new()) - log.debug("Lookup user account by unique username", username=username) - - user = self._uow.user_repo.get_by_name(username, DeletionPolicy.NOT_DELETED) - - if user is None: - if error_if_not_found: - log.debug("User not found", username=username) - raise UserDoesNotExistError - - return None - - return user - - class UserPasswordService(object): """The service methods used to manage user passwords.""" @@ -545,6 +459,7 @@ def __init__( Args: password_service: A PasswordService object. + uow: A UnitOfWork instance """ self._password_service = password_service self._uow = uow