From 614345f1c48084b63b590988f7d3b83613bddcb3 Mon Sep 17 00:00:00 2001 From: abyrne Date: Wed, 29 Nov 2023 14:27:46 -0500 Subject: [PATCH 1/6] feat: add group and group_membership tables --- src/dioptra/restapi/dependencies.py | 10 + src/dioptra/restapi/group/__init__.py | 28 + src/dioptra/restapi/group/controller.py | 118 +++++ src/dioptra/restapi/group/dependencies.py | 51 ++ src/dioptra/restapi/group/errors.py | 44 ++ src/dioptra/restapi/group/model.py | 99 ++++ src/dioptra/restapi/group/routes.py | 41 ++ src/dioptra/restapi/group/schema.py | 80 +++ src/dioptra/restapi/group/service.py | 136 +++++ .../restapi/group_membership/__init__.py | 28 + .../restapi/group_membership/controller.py | 104 ++++ .../restapi/group_membership/dependencies.py | 51 ++ .../restapi/group_membership/errors.py | 46 ++ src/dioptra/restapi/group_membership/model.py | 67 +++ .../restapi/group_membership/routes.py | 42 ++ .../restapi/group_membership/schema.py | 88 ++++ .../restapi/group_membership/service.py | 106 ++++ src/dioptra/restapi/models.py | 4 + src/dioptra/restapi/routes.py | 4 + tests/unit/restapi/group/__init__.py | 16 + tests/unit/restapi/group/test_group.py | 477 ++++++++++++++++++ 21 files changed, 1640 insertions(+) create mode 100644 src/dioptra/restapi/group/__init__.py create mode 100644 src/dioptra/restapi/group/controller.py create mode 100644 src/dioptra/restapi/group/dependencies.py create mode 100644 src/dioptra/restapi/group/errors.py create mode 100644 src/dioptra/restapi/group/model.py create mode 100644 src/dioptra/restapi/group/routes.py create mode 100644 src/dioptra/restapi/group/schema.py create mode 100644 src/dioptra/restapi/group/service.py create mode 100644 src/dioptra/restapi/group_membership/__init__.py create mode 100644 src/dioptra/restapi/group_membership/controller.py create mode 100644 src/dioptra/restapi/group_membership/dependencies.py create mode 100644 src/dioptra/restapi/group_membership/errors.py create mode 100644 src/dioptra/restapi/group_membership/model.py create mode 100644 src/dioptra/restapi/group_membership/routes.py create mode 100644 src/dioptra/restapi/group_membership/schema.py create mode 100644 src/dioptra/restapi/group_membership/service.py create mode 100644 tests/unit/restapi/group/__init__.py create mode 100644 tests/unit/restapi/group/test_group.py diff --git a/src/dioptra/restapi/dependencies.py b/src/dioptra/restapi/dependencies.py index c61bc8f0e..82d290182 100644 --- a/src/dioptra/restapi/dependencies.py +++ b/src/dioptra/restapi/dependencies.py @@ -29,6 +29,10 @@ def bind_dependencies(binder: Binder) -> None: binder: A :py:class:`~injector.Binder` object. """ from .experiment import bind_dependencies as attach_experiment_dependencies + from .group import bind_dependencies as attach_group_dependencies + from .group_membership import ( + bind_dependencies as attach_group_membership_dependencies, + ) from .job import bind_dependencies as attach_job_dependencies from .queue import bind_dependencies as attach_job_queue_dependencies from .task_plugin import bind_dependencies as attach_task_plugin_dependencies @@ -40,6 +44,8 @@ def bind_dependencies(binder: Binder) -> None: attach_job_queue_dependencies(binder) attach_task_plugin_dependencies(binder) attach_user_dependencies(binder) + attach_group_dependencies(binder) + attach_group_membership_dependencies(binder) def register_providers(modules: List[Callable[..., Any]]) -> None: @@ -54,6 +60,8 @@ def register_providers(modules: List[Callable[..., Any]]) -> None: from .queue import register_providers as attach_job_queue_providers from .task_plugin import register_providers as attach_task_plugin_providers from .user import register_providers as attach_user_providers + from .group import register_routes as attach_group_providers + from .group_membership import register_routes as attach_group_membership_providers # Append modules to list attach_experiment_providers(modules) @@ -61,3 +69,5 @@ def register_providers(modules: List[Callable[..., Any]]) -> None: attach_job_queue_providers(modules) attach_task_plugin_providers(modules) attach_user_providers(modules) + attach_group_providers(modules) + attach_group_membership_providers(modules) diff --git a/src/dioptra/restapi/group/__init__.py b/src/dioptra/restapi/group/__init__.py new file mode 100644 index 000000000..543416304 --- /dev/null +++ b/src/dioptra/restapi/group/__init__.py @@ -0,0 +1,28 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The group endpoint subpackage.""" + +from .dependencies import bind_dependencies, register_providers +from .errors import register_error_handlers +from .routes import register_routes + +__all__ = [ + "bind_dependencies", + "register_error_handlers", + "register_providers", + "register_routes", +] diff --git a/src/dioptra/restapi/group/controller.py b/src/dioptra/restapi/group/controller.py new file mode 100644 index 000000000..7d94f7017 --- /dev/null +++ b/src/dioptra/restapi/group/controller.py @@ -0,0 +1,118 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The module defining the group endpoints.""" +from __future__ import annotations + +import uuid +from typing import Any, List, Optional + +import structlog +from flask import request +from flask_accepts import accepts, responds +from flask_restx import Namespace, Resource +from injector import inject +from structlog.stdlib import BoundLogger + +from dioptra.restapi.utils import slugify + +from .errors import GroupDoesNotExistError +from .model import Group +from .schema import GroupSchema +from .service import GroupService + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +api: Namespace = Namespace( + "Group", + description="Group submission and management operations", +) + + +@api.route("/") +class GroupResource(Resource): + """Shows a list of all Group, and lets you POST to create new groups.""" + + @inject + def __init__( + self, + *args, + group_service: GroupService, + **kwargs, + ) -> None: + self._group_service = group_service + super().__init__(*args, **kwargs) + + @responds(schema=GroupSchema(many=True), api=api) + def get(self) -> List[Group]: + """Gets a list of all groups.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="group", request_type="GET" + ) # noqa: F841 + log.info("Request received") + return self._group_service.get_all(log=log) + + @accepts(GroupSchema, api=api) + @responds(schema=GroupSchema, api=api) + def post(self) -> Group: + """Creates a new Group via a group submission form with an attached file.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="group", request_type="POST" + ) # noqa: F841 + + log.info("Request received") + + parsed_obj = request.parsed_obj # type: ignore + name = slugify(str(parsed_obj["group_name"])) + return self._group_service.submit(name=name, log=log) + + @accepts(GroupSchema, api=api) + def delete(self) -> dict[str, Any]: + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="group", request_type="POST" + ) # noqa: F841 + + log.info("Request received") + + parsed_obj = request.parsed_obj # type: ignore + group_id = int(parsed_obj["id"]) + return self._group_service.delete(id=group_id) + + +@api.route("/") +@api.param("groupId", "A string specifying a group's UUID.") +class GroupIdResource(Resource): + """Shows a single job.""" + + @inject + def __init__(self, *args, _service: GroupService, **kwargs) -> None: + self._group_service = _service + super().__init__(*args, **kwargs) + + @responds(schema=GroupSchema, api=api) + def get(self, groupId: int) -> Group: + """Gets a group by its unique identifier.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="groupId", request_type="GET" + ) # noqa: F841 + log.info("Request received", group_id=groupId) + group: Optional[Group] = self._group_service.get_by_id(groupId, log=log) + + if group is None: + log.error("Group not found", group_id=groupId) + raise GroupDoesNotExistError + + return group diff --git a/src/dioptra/restapi/group/dependencies.py b/src/dioptra/restapi/group/dependencies.py new file mode 100644 index 000000000..a515d71a9 --- /dev/null +++ b/src/dioptra/restapi/group/dependencies.py @@ -0,0 +1,51 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Binding configurations to shared services using dependency injection.""" +from __future__ import annotations + +from typing import Any, Callable + +from injector import Binder, Module, provider + +from .service import GroupService + + +class GroupServiceModule(Module): + @provider + def provide_queue_name_service_module( + self, + ) -> GroupService: + return GroupService() + + +def bind_dependencies(binder: Binder) -> None: + """Binds interfaces to implementations within the main application. + + Args: + binder: A :py:class:`~injector.Binder` object. + """ + pass + + +def register_providers(modules: list[Callable[..., Any]]) -> None: + """Registers type providers within the main application. + + Args: + modules: A list of callables used for configuring the dependency injection + environment. + """ + modules.append(GroupServiceModule) diff --git a/src/dioptra/restapi/group/errors.py b/src/dioptra/restapi/group/errors.py new file mode 100644 index 000000000..32b98c4dc --- /dev/null +++ b/src/dioptra/restapi/group/errors.py @@ -0,0 +1,44 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Error handlers for the group endpoints.""" +from __future__ import annotations + +from flask_restx import Api + + +class GroupDoesNotExistError(Exception): + """The requested group does not exist.""" + + +class GroupSubmissionError(Exception): + """The Group submission form contains invalid parameters.""" + + +def register_error_handlers(api: Api) -> None: + @api.errorhandler(GroupDoesNotExistError) + def handle_job_does_not_exist_error(error): + return {"message": "Not Found - The requested group does not exist"}, 404 + + @api.errorhandler(GroupSubmissionError) + def handle_job_submission_error(error): + return ( + { + "message": "Bad Request - The group submission form contains " + "invalid parameters. Please verify and resubmit." + }, + 400, + ) diff --git a/src/dioptra/restapi/group/model.py b/src/dioptra/restapi/group/model.py new file mode 100644 index 000000000..1bbb2ba10 --- /dev/null +++ b/src/dioptra/restapi/group/model.py @@ -0,0 +1,99 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The data models for the job endpoint objects.""" +from __future__ import annotations + +from dioptra.restapi.app import db +from dioptra.restapi.group_membership.model import GroupMembership +from dioptra.restapi.user.model import User + + +class Group(db.Model): + """The Groups table. + + Attributes: + group_id: The unique identifier of the group. + name: Human-readable name for the group. + creator_id: The id for the user that created the group. + owner_id: The id for the user that owns the group. + created_on: The time at which the group was created. + deleted: Whether the group has been deleted. + """ + + __tablename__ = "groups" + + group_id = db.Column(db.BigInteger(), primary_key=True) + name = db.Column(db.String(36)) + + creator_id = db.Column(db.BigInteger(), db.ForeignKey("users.user_id"), index=True) + owner_id = db.Column(db.BigInteger(), db.ForeignKey("users.user_id"), index=True) + + created_on = db.Column(db.DateTime()) + deleted = db.Column(db.Boolean) + + creator = db.relationship("User", foreign_keys=[creator_id]) + owner = db.relationship("User", foreign_keys=[owner_id]) + + @classmethod + def next_id(cls) -> int: + """Generates the next id in the sequence.""" + group: Group | None = cls.query.order_by(cls.group_id.desc()).first() + + if group is None: + return 1 + + return int(group.id) + 1 + + @property + def users(self): + """The users that are members of the group.""" + return ( + User.query.join(GroupMembership) + .filter(GroupMembership.group_id == self.group_id) + .all() + ) + + def check_membership(self, user: User) -> bool: + """Check if the user has permission to perform the specified action. + + Args: + user: The user to check. + action: The action to check. + + Returns: + True if the user has permission to perform the action, False otherwise. + """ + membership = GroupMembership.query.filter_by( + GroupMembership.user_id == user.user_id, + GroupMembership.group_id == self.group_id, + ) + + if membership is None: + return False + else: + return True + + def update(self, changes: dict): + """Updates the record. + + Args: + changes: A dictionary containing record updates. + """ + for key, val in changes.items(): + setattr(self, key, val) + + return self diff --git a/src/dioptra/restapi/group/routes.py b/src/dioptra/restapi/group/routes.py new file mode 100644 index 000000000..2786cd476 --- /dev/null +++ b/src/dioptra/restapi/group/routes.py @@ -0,0 +1,41 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Methods for registering the group endpoint routes with the main application. + +.. |Api| replace:: :py:class:`flask_restx.Api` +.. |Flask| replace:: :py:class:`flask.Flask` +""" +from __future__ import annotations + +from flask import Flask +from flask_restx import Api + +BASE_ROUTE: str = "group" + + +def register_routes(api: Api, app: Flask, root: str = "api") -> None: + """Registers the job endpoint routes with the main application. + + Args: + api: The main REST |Api| object. + app: The main |Flask| application. + root: The root path for the registration prefix of the namespace. The default + is `"api"`. + """ + from .controller import api as endpoint_api + + api.add_namespace(endpoint_api, path=f"/{root}/{BASE_ROUTE}") diff --git a/src/dioptra/restapi/group/schema.py b/src/dioptra/restapi/group/schema.py new file mode 100644 index 000000000..9c8d67c4d --- /dev/null +++ b/src/dioptra/restapi/group/schema.py @@ -0,0 +1,80 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The schemas for serializing/deserializing the job endpoint objects. + +.. |Job| replace:: :py:class:`~.model.Job` +.. |JobForm| replace:: :py:class:`~.model.JobForm` +.. |JobFormData| replace:: :py:class:`~.model.JobFormData` +""" +from __future__ import annotations + +from typing import Any, Dict + +from marshmallow import Schema, fields, post_load + +from .model import Group + + +class GroupSchema(Schema): + """The schema for the data stored in a |Group| object. + + Attributes: + group_id: The unique identifier of the group. + name: Human-readable name for the group. + creator_id: The id for the user that created the group. + owner_id: The id for the user that owns the group. + created_on: The time at which the group was created. + deleted: Whether the group has been deleted. + """ + + __model__ = Group + + group_id = fields.Integer( + attribute="id", metadata=dict(description="A UUID that identifies the group.") + ) + name = fields.String( + attribute="name", + allow_none=True, # should we force the user to pick a name? + metadata=dict( + description="Human-readable name for the group.", + ), + ) + creator_id = fields.Integer( + attribute="creator_id", + metadata=dict( + description="An integer identifying" "the user that created the group." + ), + ) + owner_id = fields.Integer( + attribute="owner_id", + metadata=dict( + description="An integer identifying the user that owns" "the group." + ), + ) + createdOn = fields.DateTime( + attribute="created_on", + metadata=dict(description="The date and time the group was created."), + ) + deleted = fields.Boolean( + attribute="deleted", + metadata=dict(description="Whether the group has been deleted."), + ) + + @post_load + def deserialize_object(self, data: Dict[str, Any], many: bool, **kwargs) -> Group: + """Creates a |Job| object from the validated data.""" + return self.__model__(**data) diff --git a/src/dioptra/restapi/group/service.py b/src/dioptra/restapi/group/service.py new file mode 100644 index 000000000..650301e68 --- /dev/null +++ b/src/dioptra/restapi/group/service.py @@ -0,0 +1,136 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The server-side functions that perform job endpoint operations.""" +from __future__ import annotations + +import datetime +from typing import Any, List, cast + +import structlog +from sqlalchemy.exc import IntegrityError +from structlog.stdlib import BoundLogger + +from dioptra.restapi.app import db + +from .errors import GroupDoesNotExistError +from .model import Group + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + + +class GroupService(object): + """The service methods for registering and managing groups by their unique id.""" + + @staticmethod + def create(name: str, user_id=None, **kwargs) -> Group: + """Create a new group. + + Args: + name: The name of the group. + user_id: The id of the user creating the group. + + Returns: + The newly created group object. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + timestamp = datetime.datetime.now() + + # #to be used when user is fully implemented + # if user_id is None: + # user_id= current_user.id + + return Group( + group_id=Group.next_id(), + name=name, + creator_id=user_id, + owner_id=user_id, + created_on=timestamp, + deleted=False, + ) + + @staticmethod + def get_all(**kwargs) -> List[Group]: + """Fetch the list of all groups. + + Returns: + A list of groups. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + + log.info("Get full list of groups.") + + return Group.query.all() # type: ignore + + @staticmethod + def get_by_id( + group_id: int, error_if_not_found: bool = False, **kwargs + ) -> Group | None: + """Fetch a group by its unique id. + + Args: + group_id: The unique id of the group. + + Returns: + The group object if found, otherwise None. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + + log.info("Get group by id", group_id=group_id) + group = Group.query.filter_by(group_id=group_id, deleted=False).first() + + if group is None: + if error_if_not_found: + log.error("Group not found", group_id=group_id) + raise GroupDoesNotExistError + return None + + return cast(Group, group) + + def submit(self, name: str, user_id=None, **kwargs) -> Group: + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + new_group: Group = self.create(name, user_id, log=log) + + db.session.add(new_group) + db.session.commit() + + log.info("Group submission successful", group_id=new_group.group_id) + + return new_group + + def delete(self, id: int, **kwargs) -> dict[str, Any]: + """Delete a group. + + Args: + group_id: The unique id of the group. + + Returns: + A dictionary reporting the status of the request. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + if (group := self.get_by_id(id, log=log)) is None: + return {"status": "Success", "id": []} + group.update(changes={"deleted": True}) + try: + db.session.commit() + + log.info("Group deleted", group_id=id) + return {"status": "Success", "id": [id]} + except IntegrityError: + db.session.rollback() + return {"status": "Failure", "id": [id]} diff --git a/src/dioptra/restapi/group_membership/__init__.py b/src/dioptra/restapi/group_membership/__init__.py new file mode 100644 index 000000000..da65ce662 --- /dev/null +++ b/src/dioptra/restapi/group_membership/__init__.py @@ -0,0 +1,28 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The group membership endpoint subpackage.""" + +from .dependencies import bind_dependencies, register_providers +from .errors import register_error_handlers +from .routes import register_routes + +__all__ = [ + "bind_dependencies", + "register_error_handlers", + "register_providers", + "register_routes", +] diff --git a/src/dioptra/restapi/group_membership/controller.py b/src/dioptra/restapi/group_membership/controller.py new file mode 100644 index 000000000..34a118c60 --- /dev/null +++ b/src/dioptra/restapi/group_membership/controller.py @@ -0,0 +1,104 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The module defining the group membership endpoints.""" +from __future__ import annotations + +import uuid + +import structlog +from flask import request +from flask_accepts import accepts, responds +from flask_restx import Namespace, Resource +from injector import inject +from structlog.stdlib import BoundLogger + +from .model import GroupMembership +from .schema import GroupMembershipSchema +from .service import GroupMembershipService + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +api: Namespace = Namespace( + "GroupMembership", + description="Add users to groups", +) + + +@api.route("/") +class GroupMembershipResource(Resource): + """Manage group memberships.""" + + @inject + def __init__( + self, + *args, + group_membership_service: GroupMembershipService, + **kwargs, + ) -> None: + self._group_membership_service = group_membership_service + super().__init__(*args, **kwargs) + + @responds(schema=GroupMembershipSchema(many=True), api=api) + def get(self) -> list[GroupMembership]: + """Get a list of all group memberships.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), + resource="group_membership", + request_type="GET", + ) + log.info("Request received") + return self._group_membership_service.get_all(log=log) + + @accepts(GroupMembershipSchema, api=api) + @responds(schema=GroupMembershipSchema, api=api) + def post(self) -> GroupMembership: + """Create a new group membership using a group membership submission form.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), + resource="group_membership", + request_type="POST", + ) + + log.info("Request received") + + parsed_obj = request.parsed_obj # type: ignore + group_id = int(parsed_obj["group_id"]) + user_id = int(parsed_obj["user_id"]) + read = bool(parsed_obj["read"]) + write = bool(parsed_obj["write"]) + share_read = bool(parsed_obj["share_read"]) + share_write = bool(parsed_obj["share_write"]) + + return self._group_membership_service.submit( + group_id, user_id, read, write, share_read, share_write, log=log + ) + + @accepts(GroupMembershipSchema, api=api) + def delete(self) -> bool: + """Delete a group membership.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), + resource="group_membership", + request_type="DELETE", + ) + + log.info("Request received") + + parsed_obj = request.parsed_obj # type: ignore + group_id = int(parsed_obj["group_id"]) + user_id = int(parsed_obj["user_id"]) + return self._group_membership_service.delete(group_id, user_id, log=log) diff --git a/src/dioptra/restapi/group_membership/dependencies.py b/src/dioptra/restapi/group_membership/dependencies.py new file mode 100644 index 000000000..bf5e5428d --- /dev/null +++ b/src/dioptra/restapi/group_membership/dependencies.py @@ -0,0 +1,51 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Binding configurations to shared services using dependency injection.""" +from __future__ import annotations + +from typing import Any, Callable + +from injector import Binder, Module, provider + +from .service import GroupMembershipService + + +class GroupMembershipServiceModule(Module): + @provider + def provide_queue_name_service_module( + self, + ) -> GroupMembershipService: + return GroupMembershipService() + + +def bind_dependencies(binder: Binder) -> None: + """Binds interfaces to implementations within the main application. + + Args: + binder: A :py:class:`~injector.Binder` object. + """ + pass + + +def register_providers(modules: list[Callable[..., Any]]) -> None: + """Registers type providers within the main application. + + Args: + modules: A list of callables used for configuring the dependency injection + environment. + """ + modules.append(GroupMembershipServiceModule) diff --git a/src/dioptra/restapi/group_membership/errors.py b/src/dioptra/restapi/group_membership/errors.py new file mode 100644 index 000000000..a874a13f9 --- /dev/null +++ b/src/dioptra/restapi/group_membership/errors.py @@ -0,0 +1,46 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Error handlers for the group membership endpoints.""" +from __future__ import annotations + +from flask_restx import Api + + +class GroupMembershipDoesNotExistError(Exception): + """The requested group membership does not exist.""" + + +class GroupMembershipSubmissionError(Exception): + """The group membership submission form contains invalid parameters.""" + + +def register_error_handlers(api: Api) -> None: + @api.errorhandler(GroupMembershipDoesNotExistError) + def handle_job_does_not_exist_error(error): + return { + "message": "Not Found - The requested group membership does not exist" + }, 404 + + @api.errorhandler(GroupMembershipSubmissionError) + def handle_job_submission_error(error): + return ( + { + "message": "Bad Request - The group membership submission form contains" + "invalid parameters. Please verify and resubmit." + }, + 400, + ) diff --git a/src/dioptra/restapi/group_membership/model.py b/src/dioptra/restapi/group_membership/model.py new file mode 100644 index 000000000..7c914905e --- /dev/null +++ b/src/dioptra/restapi/group_membership/model.py @@ -0,0 +1,67 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The data models for the job endpoint objects.""" +from __future__ import annotations + +from dioptra.restapi.app import db + + +class GroupMembership(db.Model): + """The group memberships table. + + Attributes: + user_id (int): The ID of the user who is a member of the group. + group_id (str): The ID of the group to which the user belongs. + read (bool): Indicates whether the user has read permissions in the group. + write (bool): Indicates whether the user has write permissions in the group. + share_read (bool): Indicates whether the user can share read permissions with + others in the group. + share_write (bool): Indicates whether the user can share write permissions + with others in the group. + """ + + __tablename__ = "group_memberships" + + user_id = db.Column( + db.BigInteger(), db.ForeignKey("users.user_id"), primary_key=True + ) + group_id = db.Column( + db.BigInteger(), db.ForeignKey("groups.group_id"), primary_key=True + ) + + read = db.Column(db.Boolean, default=False) + write = db.Column(db.Boolean, default=False) + share_read = db.Column(db.Boolean, default=False) + share_write = db.Column(db.Boolean, default=False) + + # is back populates needed? + user = db.relationship( + "User", + foreign_keys=[user_id], + ) + group = db.relationship("Group", foreign_keys=[group_id]) + + def update(self, changes: dict): + """Updates the record. + + Args: + changes: A dictionary containing record updates. + """ + for key, val in changes.items(): + setattr(self, key, val) + + return self diff --git a/src/dioptra/restapi/group_membership/routes.py b/src/dioptra/restapi/group_membership/routes.py new file mode 100644 index 000000000..870ce937d --- /dev/null +++ b/src/dioptra/restapi/group_membership/routes.py @@ -0,0 +1,42 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Methods for registering the group membership endpoint routes with the main +application. + +.. |Api| replace:: :py:class:`flask_restx.Api` +.. |Flask| replace:: :py:class:`flask.Flask` +""" +from __future__ import annotations + +from flask import Flask +from flask_restx import Api + +BASE_ROUTE: str = "group_membership" + + +def register_routes(api: Api, app: Flask, root: str = "api") -> None: + """Registers the job endpoint routes with the main application. + + Args: + api: The main REST |Api| object. + app: The main |Flask| application. + root: The root path for the registration prefix of the namespace. The default + is `"api"`. + """ + from .controller import api as endpoint_api + + api.add_namespace(endpoint_api, path=f"/{root}/{BASE_ROUTE}") diff --git a/src/dioptra/restapi/group_membership/schema.py b/src/dioptra/restapi/group_membership/schema.py new file mode 100644 index 000000000..6bc2ac88f --- /dev/null +++ b/src/dioptra/restapi/group_membership/schema.py @@ -0,0 +1,88 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The schemas for serializing/deserializing the job endpoint objects. + +.. |Job| replace:: :py:class:`~.model.Job` +.. |JobForm| replace:: :py:class:`~.model.JobForm` +.. |JobFormData| replace:: :py:class:`~.model.JobFormData` +""" +from __future__ import annotations + +from typing import Any, Dict + +from marshmallow import Schema, fields, post_load + +from .model import GroupMembership + + +class GroupMembershipSchema(Schema): + """The schema for the data stored in a GroupMembership object. + + Attributes: + user_id: The ID of the user who is a member of the group. + group_id: The ID of the group to which the user belongs. + read: Indicates whether the user has read permissions in the group. + write: Indicates whether the user has write permissions in the group. + share_read: Indicates whether the user can share read permissions with others + in the group. + share_write: Indicates whether the user can share write permissions with + others in the group. + """ + + __model__ = GroupMembership + + user_id = fields.Integer( + attribute="user_id", + metadata=dict(description="The ID of the user who is a member of the group."), + ) + group_id = fields.String( + attribute="group_id", + metadata=dict(description="The ID of the group to which the user belongs."), + ) + read = fields.Boolean( + attribute="read", + metadata=dict( + description="Indicates whether the user has read permissions in the group." + ), + ) + write = fields.Boolean( + attribute="write", + metadata=dict( + description="Indicates whether the user has write permissions in the group." + ), + ) + share_read = fields.Boolean( + attribute="share_read", + metadata=dict( + description="Indicates whether the user can share read permissions with \ + others in the group." + ), + ) + share_write = fields.Boolean( + attribute="share_write", + metadata=dict( + description="Indicates whether the user can share write permissions with \ + others in the group." + ), + ) + + @post_load + def deserialize_object( + self, data: Dict[str, Any], many: bool, **kwargs + ) -> GroupMembership: + """Creates a GroupMembership object from the validated data.""" + return self.__model__(**data) diff --git a/src/dioptra/restapi/group_membership/service.py b/src/dioptra/restapi/group_membership/service.py new file mode 100644 index 000000000..1fb1054ff --- /dev/null +++ b/src/dioptra/restapi/group_membership/service.py @@ -0,0 +1,106 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The server-side functions that perform group membership endpoint operations.""" +from __future__ import annotations + +from typing import List + +import structlog +from sqlalchemy.exc import IntegrityError +from structlog.stdlib import BoundLogger + +from dioptra.restapi.app import db + +from .model import GroupMembership + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + + +class GroupMembershipService(object): + @staticmethod + def create( + group_id: int, + user_id: int, + read: bool, + write: bool, + share_read: bool, + share_write: bool, + **kwargs, + ) -> GroupMembership: + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841\ + + return GroupMembership( + group_id=group_id, + user_id=user_id, + read=read, + write=write, + share_read=share_read, + share_write=share_write, + ) + + @staticmethod + def get_all(**kwargs) -> List[GroupMembership]: + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + + return GroupMembership.query.all() # type: ignore + + @staticmethod + def get_by_id(group_id: int, user_id: int, **kwargs) -> GroupMembership | None: + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + + return GroupMembership.query.filter( # type: ignore + GroupMembership.user_id == user_id, GroupMembership.group_id == group_id + ).first() + + def submit( + self, + group_id: int, + user_id: int, + read: bool, + write: bool, + share_read: bool, + share_write: bool, + **kwargs, + ) -> GroupMembership: + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + new_group_membership: GroupMembership = self.create( + group_id, user_id, read, write, share_read, share_write, log=log + ) + + db.session.add(new_group_membership) + db.session.commit() + + log.info( + "Group Membership submission successful", + group_id=new_group_membership.group_id, + user_id=new_group_membership.user_id, + ) + + return new_group_membership + + def delete(self, group_id, user_id, **kwargs) -> bool: + membership = self.get_by_id(group_id=group_id, user_id=user_id) + + try: + db.session.delete(membership) + db.session.commit() + + return True + except IntegrityError: + db.session.rollback() + return False diff --git a/src/dioptra/restapi/models.py b/src/dioptra/restapi/models.py index 925b9c928..bca58e775 100644 --- a/src/dioptra/restapi/models.py +++ b/src/dioptra/restapi/models.py @@ -22,6 +22,8 @@ ExperimentRegistrationForm, ExperimentRegistrationFormData, ) +from .group.model import Group +from .group_membership.model import GroupMembership from .job.model import Job, JobForm, JobFormData from .queue.model import Queue, QueueLock from .task_plugin.model import ( @@ -44,4 +46,6 @@ "TaskPluginUploadForm", "TaskPluginUploadFormData", "User", + "Group", + "GroupMembership", ] diff --git a/src/dioptra/restapi/routes.py b/src/dioptra/restapi/routes.py index 1718ad720..736011923 100644 --- a/src/dioptra/restapi/routes.py +++ b/src/dioptra/restapi/routes.py @@ -34,6 +34,8 @@ def register_routes(api: Api, app: Flask) -> None: """ from .auth import register_routes as attach_auth from .experiment import register_routes as attach_experiment + from .group import register_routes as attach_group + from .group_membership import register_routes as attach_group_membership from .job import register_routes as attach_job from .queue import register_routes as attach_job_queue from .task_plugin import register_routes as attach_task_plugin @@ -45,3 +47,5 @@ def register_routes(api: Api, app: Flask) -> None: attach_job_queue(api, app) attach_task_plugin(api, app) attach_user(api, app) + attach_group(api, app) + attach_group_membership(api, app) diff --git a/tests/unit/restapi/group/__init__.py b/tests/unit/restapi/group/__init__.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/tests/unit/restapi/group/__init__.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/tests/unit/restapi/group/test_group.py b/tests/unit/restapi/group/test_group.py new file mode 100644 index 000000000..77d26ce45 --- /dev/null +++ b/tests/unit/restapi/group/test_group.py @@ -0,0 +1,477 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Test suite for queue operations. + +This module contains a set of tests that validate the CRUD operations and additional +functionalities for the queue entity. The tests ensure that the queues can be +registered, renamed, deleted, and locked/unlocked as expected through the REST API. +""" +from __future__ import annotations + +import datetime +from typing import Any + +import pytest +from flask.testing import FlaskClient +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from werkzeug.test import TestResponse + +from dioptra.restapi.group.model import Group +from dioptra.restapi.group.service import GroupService +from dioptra.restapi.group_membership.service import GroupMembershipService +from dioptra.restapi.queue.routes import BASE_ROUTE as QUEUE_BASE_ROUTE +from dioptra.restapi.user.model import User +from dioptra.restapi.group.model import Group +from dioptra.restapi.group_membership.model import GroupMembership + + +@pytest.fixture +def group_service() -> GroupService: + yield GroupService() + + +@pytest.fixture +def group_membership_service() -> GroupMembershipService: + yield GroupMembershipService() + + +###### helpers + + +def create_user(db: SQLAlchemy) -> User: + """Create a user and add them to the database. + + Args: + db: The SQLAlchemy database session. + + Returns: + The newly created user object. + """ + timestamp = datetime.datetime.now() + user_expire_on = datetime.datetime(9999, 12, 31, 23, 59, 59) + password_expire_on = timestamp.replace(year=timestamp.year + 1) + + new_user: User = User( + username="test_admin", + password="password", + email_address="test@test.com", + created_on=timestamp, + last_modified_on=timestamp, + last_login_on=timestamp, + password_expire_on=password_expire_on, + ) + db.session.add(new_user) + db.session.commit() + return new_user + + +def create_group(group_service: GroupService, name: str = "test") -> Group: + """Create a group using the group service. + + Args: + group_service: The group service responsible for group creation. + name: The name to assign to the new group (default is "test"). + + Returns: + The response from the group service representing the newly created group. + """ + return group_service.submit(name) + + +def get_group(id: int, group_service: GroupService) -> Group | None: + """Retrieve a group by its unique identifier. + + Args: + id: The unique identifier of the group. + group_service: The service responsible for handling group-related operations. + + Returns: + The retrieved Group object. + + Raises: + GroupDoesNotExistError: If no group with the specified ID is found. + """ + return group_service.get_by_id(id) + + +def delete_group(id: int, group_service: GroupService) -> dict[str, Any]: + """Delete a group by its unique identifier. + + Args: + id: The unique identifier of the group. + group_service: The service responsible for handling group operations. + + Returns: + A dictionary reporting the status of the request. + + Raises: + GroupDoesNotExistError: If the group with the specified ID does not exist. + """ + return group_service.delete(id) + + +def create_group_membership( + group_id: int, + user_id: int, + read: bool, + write: bool, + share_read: bool, + share_write: bool, + group_membership_service: GroupMembershipService, +) -> GroupMembershipService | None: + """Create a group membership. + + Args: + group_id: The unique identifier of the group. + user_id: The unique identifier of the user. + read: Whether the user has read permissions. + write: Whether the user has write permissions. + share_read: Whether the user has share-read permissions. + share_write: Whether the user has share-write permissions. + group_membership_service: The service responsible for handling group memberships. + + Returns: + The created GroupMembership object representing the group membership. + + Raises: + GroupMembershipSubmissionError: If there is an issue with the submission. + """ + return group_membership_service.submit( + group_id, + user_id, + read=read, + write=write, + share_read=share_read, + share_write=share_write, + ) + + +def get_group_membership( + user_id: int, group_id: int, group_membership_service: GroupMembershipService +) -> GroupMembership | None: + """Retrieve a group membership for a user in a specific group. + + Args: + user_id: The unique identifier of the user. + group_id: The unique identifier of the group. + group_membership_service: The service responsible for handling group membership operations. + + Returns: + The retrieved GroupMembership object if found, otherwise None. + + Raises: + GroupMembershipDoesNotExistError: If no group membership with the specified user and group IDs is found. + """ + return group_membership_service.get_by_id(group_id, user_id) + + +def delete_group_membership( + group_id: int, user_id: int, group_membership_service: Any +) -> dict[str, Any]: + """Delete a group membership. + + Args: + group_id: The unique identifier of the group. + user_id: The unique identifier of the user. + group_membership_service: The service responsible for group membership operations. + + Returns: + A dictionary reporting the status of the request. + """ + return group_membership_service.delete(group_id, user_id) + + +##### asserts + + +def assert_group_membership_created(membership: Any, group: Any, new_user: Any) -> None: + """Assert that the group membership has been created. + + Args: + membership: The created group membership object. + group: The group to which the user is added. + new_user: The user added to the group. + """ + assert membership.group_id == group.group_id + assert membership.user_id == new_user.user_id + assert membership.read is True + assert membership.write is True + assert membership.share_read is True + assert membership.share_write is True + + assert new_user in group.users + + +def assert_group_membership_does_not_exist( + membership: Any, group: Any, new_user: Any +) -> None: + """Assert that the group membership has not been created. + + Args: + membership: The created group membership object (should be None). + group: The group to which the user should not be added. + new_user: The user that should not be added to the group. + """ + assert membership is None + + +def assert_membership_group( + retrieved_membership: GroupMembership, + group: Group, +) -> None: + """Assert that the group in the retrieved membership matches the expected group. + + Args: + retrieved_membership: The retrieved membership object. + group: The expected group object. + + Raises: + AssertionError: If the groups do not match. + """ + assert retrieved_membership.group == group + + +def assert_group_in_list( + group: Group, + group_service: GroupService, +) -> None: + """Assert that the group is in the list of all groups retrieved from the service. + + Args: + group: The group object to check. + group_service: The group service. + + Raises: + AssertionError: If the group is not found in the list. + """ + assert group in group_service.get_all() + + +def assert_group_is_none(retrieved_group: Group | None) -> None: + """Assert that the retrieved group is None. + + Args: + retrieved_group: The retrieved group object. + + Raises: + AssertionError: If the retrieved group is not None. + """ + assert retrieved_group is None + + +def assert_membership_user_equals( + retrieved_membership: GroupMembership, new_user: User +) -> None: + """Assert that the user in the retrieved membership equals the new user. + + Args: + retrieved_membership: The retrieved group membership object. + new_user: The new user object. + + Raises: + AssertionError: If the user in the retrieved membership does not equal the new user. + """ + assert retrieved_membership.user == new_user + + +###### tests + + +def test_create_group(db: SQLAlchemy, group_service: GroupService): + """Test the creation of a group using the database and group service. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group creation. + """ + group = create_group(group_service, name="Test Group") + + assert_group_in_list(group, group_service) + + +def test_delete_group(db: SQLAlchemy, group_service: GroupService): + """Test the deletion of a group using the database and group service. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group operations. + """ + group = create_group(group_service, name="Test Group") + delete_group(group.group_id, group_service) + retrieved_group = get_group(group.group_id, group_service) + + assert_group_is_none(retrieved_group) + + +def test_create_group_membership( + db: SQLAlchemy, + group_service: GroupService, + group_membership_service: GroupMembershipService, +) -> None: + """Test the creation of a group membership using the database and services. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group operations. + group_membership_service: The group membership service responsible for group membership operations. + """ # Create a user + + new_user = create_user(db) + + # Create a group + group = create_group(group_service, name="Test Group") + + # Create a group membership + membership = create_group_membership( + group.group_id, + new_user.user_id, + read=True, + write=True, + share_read=True, + share_write=True, + group_membership_service=group_membership_service, + ) + + assert_group_membership_created(membership, group, new_user) + + +def test_delete_group_membership( + db: SQLAlchemy, + group_service: GroupService, + group_membership_service: GroupMembershipService, +) -> None: + """Test the deletion of a group membership using the database and services. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group operations. + group_membership_service: The group membership service responsible for group membership operations. + """ + # Create a user + + new_user = create_user(db) + + # Create a group + group = create_group(group_service, name="Test Group") + + # Create a group membership + membership = create_group_membership( + group.group_id, + new_user.user_id, + read=True, + write=True, + share_read=True, + share_write=True, + group_membership_service=group_membership_service, + ) + + # get membership from db + retrieved_membership = get_group_membership( + group.group_id, new_user.user_id, group_membership_service + ) + + # and then delete it + delete_group_membership( + retrieved_membership.group_id, + retrieved_membership.user_id, + group_membership_service, + ) + + # get membership from db + retrieved_membership = get_group_membership( + group.group_id, new_user.user_id, group_membership_service + ) + + assert_group_membership_does_not_exist(retrieved_membership, group, new_user) + + +def test_group_relationship( + db: SQLAlchemy, + group_service: GroupService, + group_membership_service: GroupMembershipService, +) -> None: + """Test the relationship between groups and group memberships using the database and services. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group operations. + group_membership_service: The group membership service responsible for group membership operations. + """ + # Create a user + + new_user = create_user(db) + + # Create a group + group = create_group(group_service, name="Test Group") + + # Create a group membership + membership = create_group_membership( + group.group_id, + new_user.user_id, + read=True, + write=True, + share_read=True, + share_write=True, + group_membership_service=group_membership_service, + ) + + # get membership from db + retrieved_membership = get_group_membership( + group.group_id, new_user.user_id, group_membership_service + ) + + assert_membership_group(retrieved_membership, group) + + +def test_user_relationship( + db: SQLAlchemy, + group_service: GroupService, + group_membership_service: GroupMembershipService, +) -> None: + """Test the relationship between users and group memberships using the database and services. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group operations. + group_membership_service: The group membership service responsible for group membership operations. + """ + # Create a user + + new_user = create_user(db) + + # Create a group + group = create_group(group_service, name="Test Group") + + # Create a group membership + membership = create_group_membership( + group.group_id, + new_user.user_id, + read=True, + write=True, + share_read=True, + share_write=True, + group_membership_service=group_membership_service, + ) + + # get membership from db + retrieved_membership = get_group_membership( + group.group_id, new_user.user_id, group_membership_service + ) + + assert_membership_user_equals(retrieved_membership, new_user) From f0c95be1b8c81956777148d57987a99c01a2caa7 Mon Sep 17 00:00:00 2001 From: abyrne Date: Wed, 29 Nov 2023 14:36:58 -0500 Subject: [PATCH 2/6] chore: ran black and isort --- src/dioptra/restapi/dependencies.py | 6 ++++-- tests/cookiecutter_dioptra_deployment/conftest.py | 2 +- tests/unit/restapi/group/test_group.py | 3 +-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/dioptra/restapi/dependencies.py b/src/dioptra/restapi/dependencies.py index 82d290182..a1912355c 100644 --- a/src/dioptra/restapi/dependencies.py +++ b/src/dioptra/restapi/dependencies.py @@ -56,12 +56,14 @@ def register_providers(modules: List[Callable[..., Any]]) -> None: environment. """ from .experiment import register_providers as attach_experiment_providers + from .group import register_providers as attach_group_providers + from .group_membership import ( + register_providers as attach_group_membership_providers, + ) from .job import register_providers as attach_job_providers from .queue import register_providers as attach_job_queue_providers from .task_plugin import register_providers as attach_task_plugin_providers from .user import register_providers as attach_user_providers - from .group import register_routes as attach_group_providers - from .group_membership import register_routes as attach_group_membership_providers # Append modules to list attach_experiment_providers(modules) diff --git a/tests/cookiecutter_dioptra_deployment/conftest.py b/tests/cookiecutter_dioptra_deployment/conftest.py index 72a3f9937..3e14e2449 100644 --- a/tests/cookiecutter_dioptra_deployment/conftest.py +++ b/tests/cookiecutter_dioptra_deployment/conftest.py @@ -117,7 +117,7 @@ def context(): "image": "node", "namespace": "", "tag": "latest", - "registry": "" + "registry": "", }, "redis": { "image": "redis", diff --git a/tests/unit/restapi/group/test_group.py b/tests/unit/restapi/group/test_group.py index 77d26ce45..b9a21db67 100644 --- a/tests/unit/restapi/group/test_group.py +++ b/tests/unit/restapi/group/test_group.py @@ -34,11 +34,10 @@ from dioptra.restapi.group.model import Group from dioptra.restapi.group.service import GroupService +from dioptra.restapi.group_membership.model import GroupMembership from dioptra.restapi.group_membership.service import GroupMembershipService from dioptra.restapi.queue.routes import BASE_ROUTE as QUEUE_BASE_ROUTE from dioptra.restapi.user.model import User -from dioptra.restapi.group.model import Group -from dioptra.restapi.group_membership.model import GroupMembership @pytest.fixture From a3eba663f909a05e9912d368a6efed043b4d360f Mon Sep 17 00:00:00 2001 From: abyrne Date: Thu, 11 Jan 2024 17:34:25 -0500 Subject: [PATCH 3/6] refactor: add docstrings and cleanup controller layer --- src/dioptra/restapi/group/controller.py | 26 ++-- src/dioptra/restapi/group/schema.py | 38 +++--- src/dioptra/restapi/group/service.py | 69 +++++------ .../restapi/group_membership/controller.py | 12 +- .../restapi/group_membership/errors.py | 12 +- .../restapi/group_membership/schema.py | 42 +++---- .../restapi/group_membership/service.py | 115 +++++++++++++----- tests/unit/restapi/group/test_group.py | 8 +- 8 files changed, 180 insertions(+), 142 deletions(-) diff --git a/src/dioptra/restapi/group/controller.py b/src/dioptra/restapi/group/controller.py index 7d94f7017..2f976bd07 100644 --- a/src/dioptra/restapi/group/controller.py +++ b/src/dioptra/restapi/group/controller.py @@ -18,7 +18,7 @@ from __future__ import annotations import uuid -from typing import Any, List, Optional +from typing import Any, cast import structlog from flask import request @@ -27,11 +27,8 @@ from injector import inject from structlog.stdlib import BoundLogger -from dioptra.restapi.utils import slugify - -from .errors import GroupDoesNotExistError from .model import Group -from .schema import GroupSchema +from .schema import GroupSchema, IdStatusResponseSchema from .service import GroupService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -57,7 +54,7 @@ def __init__( super().__init__(*args, **kwargs) @responds(schema=GroupSchema(many=True), api=api) - def get(self) -> List[Group]: + def get(self) -> list[Group]: """Gets a list of all groups.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="group", request_type="GET" @@ -65,7 +62,7 @@ def get(self) -> List[Group]: log.info("Request received") return self._group_service.get_all(log=log) - @accepts(GroupSchema, api=api) + @accepts(schema=GroupSchema, api=api) @responds(schema=GroupSchema, api=api) def post(self) -> Group: """Creates a new Group via a group submission form with an attached file.""" @@ -76,10 +73,11 @@ def post(self) -> Group: log.info("Request received") parsed_obj = request.parsed_obj # type: ignore - name = slugify(str(parsed_obj["group_name"])) - return self._group_service.submit(name=name, log=log) + name = str(parsed_obj["group_name"]) + return self._group_service.create(name=name, log=log) - @accepts(GroupSchema, api=api) + @accepts(schema=GroupSchema, api=api) + @responds(schema=IdStatusResponseSchema, api=api) def delete(self) -> dict[str, Any]: log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="group", request_type="POST" @@ -109,10 +107,6 @@ def get(self, groupId: int) -> Group: request_id=str(uuid.uuid4()), resource="groupId", request_type="GET" ) # noqa: F841 log.info("Request received", group_id=groupId) - group: Optional[Group] = self._group_service.get_by_id(groupId, log=log) - - if group is None: - log.error("Group not found", group_id=groupId) - raise GroupDoesNotExistError + group = self._group_service.get(groupId, error_if_not_found=True, log=log) - return group + return cast(Group, group) diff --git a/src/dioptra/restapi/group/schema.py b/src/dioptra/restapi/group/schema.py index 9c8d67c4d..a109a8fa1 100644 --- a/src/dioptra/restapi/group/schema.py +++ b/src/dioptra/restapi/group/schema.py @@ -22,26 +22,11 @@ """ from __future__ import annotations -from typing import Any, Dict - -from marshmallow import Schema, fields, post_load - -from .model import Group +from marshmallow import Schema, fields class GroupSchema(Schema): - """The schema for the data stored in a |Group| object. - - Attributes: - group_id: The unique identifier of the group. - name: Human-readable name for the group. - creator_id: The id for the user that created the group. - owner_id: The id for the user that owns the group. - created_on: The time at which the group was created. - deleted: Whether the group has been deleted. - """ - - __model__ = Group + """The schema for the data stored in a |Group| object.""" group_id = fields.Integer( attribute="id", metadata=dict(description="A UUID that identifies the group.") @@ -74,7 +59,18 @@ class GroupSchema(Schema): metadata=dict(description="Whether the group has been deleted."), ) - @post_load - def deserialize_object(self, data: Dict[str, Any], many: bool, **kwargs) -> Group: - """Creates a |Job| object from the validated data.""" - return self.__model__(**data) + +class IdStatusResponseSchema(Schema): + """A simple response for reporting a status for one or more objects.""" + + status = fields.String( + attribute="status", + metadata=dict(description="The status of the request."), + ) + id = fields.List( + fields.Integer(), + attribute="id", + metadata=dict( + description="A list of integers identifying the affected object(s)." + ), + ) diff --git a/src/dioptra/restapi/group/service.py b/src/dioptra/restapi/group/service.py index 650301e68..347b433bd 100644 --- a/src/dioptra/restapi/group/service.py +++ b/src/dioptra/restapi/group/service.py @@ -25,6 +25,7 @@ from structlog.stdlib import BoundLogger from dioptra.restapi.app import db +from dioptra.restapi.utils import slugify from .errors import GroupDoesNotExistError from .model import Group @@ -35,35 +36,7 @@ class GroupService(object): """The service methods for registering and managing groups by their unique id.""" - @staticmethod - def create(name: str, user_id=None, **kwargs) -> Group: - """Create a new group. - - Args: - name: The name of the group. - user_id: The id of the user creating the group. - - Returns: - The newly created group object. - """ - log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 - timestamp = datetime.datetime.now() - - # #to be used when user is fully implemented - # if user_id is None: - # user_id= current_user.id - - return Group( - group_id=Group.next_id(), - name=name, - creator_id=user_id, - owner_id=user_id, - created_on=timestamp, - deleted=False, - ) - - @staticmethod - def get_all(**kwargs) -> List[Group]: + def get_all(self, **kwargs) -> List[Group]: """Fetch the list of all groups. Returns: @@ -75,17 +48,19 @@ def get_all(**kwargs) -> List[Group]: return Group.query.all() # type: ignore - @staticmethod - def get_by_id( - group_id: int, error_if_not_found: bool = False, **kwargs + def get( + self, group_id: int, error_if_not_found: bool = False, **kwargs ) -> Group | None: """Fetch a group by its unique id. Args: group_id: The unique id of the group. - + error_if_not_found: Raise an error if the group cannot be found. Returns: The group object if found, otherwise None. + + Raises: + GroupDoesNotExistError: If the group with group_id cannot be found. """ log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 @@ -100,10 +75,32 @@ def get_by_id( return cast(Group, group) - def submit(self, name: str, user_id=None, **kwargs) -> Group: + def create(self, name: str, user_id=None, **kwargs) -> Group: + """Create a new group. + + Args: + name: The name of the group. + user_id: The id of the user creating the group. + + Returns: + The newly created group object. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - new_group: Group = self.create(name, user_id, log=log) + timestamp = datetime.datetime.now() + + # #to be used when user is fully implemented + # if user_id is None: + # user_id= current_user.id + + new_group = Group( + group_id=Group.next_id(), + name=slugify(name), + creator_id=user_id, + owner_id=user_id, + created_on=timestamp, + deleted=False, + ) db.session.add(new_group) db.session.commit() @@ -123,7 +120,7 @@ def delete(self, id: int, **kwargs) -> dict[str, Any]: """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - if (group := self.get_by_id(id, log=log)) is None: + if (group := self.get(id, log=log)) is None: return {"status": "Success", "id": []} group.update(changes={"deleted": True}) try: diff --git a/src/dioptra/restapi/group_membership/controller.py b/src/dioptra/restapi/group_membership/controller.py index 34a118c60..fd0a489a0 100644 --- a/src/dioptra/restapi/group_membership/controller.py +++ b/src/dioptra/restapi/group_membership/controller.py @@ -18,6 +18,7 @@ from __future__ import annotations import uuid +from typing import Any import structlog from flask import request @@ -27,7 +28,7 @@ from structlog.stdlib import BoundLogger from .model import GroupMembership -from .schema import GroupMembershipSchema +from .schema import GroupMembershipSchema, IdStatusResponseSchema from .service import GroupMembershipService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -63,7 +64,7 @@ def get(self) -> list[GroupMembership]: log.info("Request received") return self._group_membership_service.get_all(log=log) - @accepts(GroupMembershipSchema, api=api) + @accepts(schema=GroupMembershipSchema, api=api) @responds(schema=GroupMembershipSchema, api=api) def post(self) -> GroupMembership: """Create a new group membership using a group membership submission form.""" @@ -83,12 +84,13 @@ def post(self) -> GroupMembership: share_read = bool(parsed_obj["share_read"]) share_write = bool(parsed_obj["share_write"]) - return self._group_membership_service.submit( + return self._group_membership_service.create( group_id, user_id, read, write, share_read, share_write, log=log ) - @accepts(GroupMembershipSchema, api=api) - def delete(self) -> bool: + @accepts(schema=GroupMembershipSchema, api=api) + @responds(schema=IdStatusResponseSchema, api=api) + def delete(self) -> dict[str, Any]: """Delete a group membership.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), diff --git a/src/dioptra/restapi/group_membership/errors.py b/src/dioptra/restapi/group_membership/errors.py index a874a13f9..626c81d34 100644 --- a/src/dioptra/restapi/group_membership/errors.py +++ b/src/dioptra/restapi/group_membership/errors.py @@ -28,15 +28,23 @@ class GroupMembershipSubmissionError(Exception): """The group membership submission form contains invalid parameters.""" +class GroupMembershipAlreadyExistsError(Exception): + """The group membership submission form contains invalid parameters.""" + + def register_error_handlers(api: Api) -> None: @api.errorhandler(GroupMembershipDoesNotExistError) - def handle_job_does_not_exist_error(error): + def handle_group_membership_does_not_exist_error(error): return { "message": "Not Found - The requested group membership does not exist" }, 404 + @api.errorhandler(GroupMembershipAlreadyExistsError) + def handle_group_membership_already_exists_error(error): + return {"message": "Bad Request - The group membership already exists"}, 400 + @api.errorhandler(GroupMembershipSubmissionError) - def handle_job_submission_error(error): + def handle_group_membership_submission_error(error): return ( { "message": "Bad Request - The group membership submission form contains" diff --git a/src/dioptra/restapi/group_membership/schema.py b/src/dioptra/restapi/group_membership/schema.py index 6bc2ac88f..c7e322d39 100644 --- a/src/dioptra/restapi/group_membership/schema.py +++ b/src/dioptra/restapi/group_membership/schema.py @@ -22,28 +22,11 @@ """ from __future__ import annotations -from typing import Any, Dict - -from marshmallow import Schema, fields, post_load - -from .model import GroupMembership +from marshmallow import Schema, fields class GroupMembershipSchema(Schema): - """The schema for the data stored in a GroupMembership object. - - Attributes: - user_id: The ID of the user who is a member of the group. - group_id: The ID of the group to which the user belongs. - read: Indicates whether the user has read permissions in the group. - write: Indicates whether the user has write permissions in the group. - share_read: Indicates whether the user can share read permissions with others - in the group. - share_write: Indicates whether the user can share write permissions with - others in the group. - """ - - __model__ = GroupMembership + """The schema for the data stored in a |GroupMembership| object.""" user_id = fields.Integer( attribute="user_id", @@ -80,9 +63,18 @@ class GroupMembershipSchema(Schema): ), ) - @post_load - def deserialize_object( - self, data: Dict[str, Any], many: bool, **kwargs - ) -> GroupMembership: - """Creates a GroupMembership object from the validated data.""" - return self.__model__(**data) + +class IdStatusResponseSchema(Schema): + """A simple response for reporting a status for one or more objects.""" + + status = fields.String( + attribute="status", + metadata=dict(description="The status of the request."), + ) + id = fields.List( + fields.Integer(), + attribute="id", + metadata=dict( + description="A list of integers identifying the affected object(s)." + ), + ) diff --git a/src/dioptra/restapi/group_membership/service.py b/src/dioptra/restapi/group_membership/service.py index 1fb1054ff..682c99150 100644 --- a/src/dioptra/restapi/group_membership/service.py +++ b/src/dioptra/restapi/group_membership/service.py @@ -17,7 +17,7 @@ """The server-side functions that perform group membership endpoint operations.""" from __future__ import annotations -from typing import List +from typing import Any, cast import structlog from sqlalchemy.exc import IntegrityError @@ -25,48 +25,56 @@ from dioptra.restapi.app import db +from .errors import GroupMembershipDoesNotExistError from .model import GroupMembership LOGGER: BoundLogger = structlog.stdlib.get_logger() class GroupMembershipService(object): - @staticmethod - def create( - group_id: int, - user_id: int, - read: bool, - write: bool, - share_read: bool, - share_write: bool, - **kwargs, - ) -> GroupMembership: - log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841\ + def get_all(self, **kwargs) -> list[GroupMembership]: + """Retrieve a list of all group memberships. - return GroupMembership( - group_id=group_id, - user_id=user_id, - read=read, - write=write, - share_read=share_read, - share_write=share_write, - ) - - @staticmethod - def get_all(**kwargs) -> List[GroupMembership]: + Returns: + List[GroupMembership]: List of group memberships. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 return GroupMembership.query.all() # type: ignore - @staticmethod - def get_by_id(group_id: int, user_id: int, **kwargs) -> GroupMembership | None: + def get( + self, group_id: int, user_id: int, error_if_not_found: bool = False, **kwargs + ) -> GroupMembership | None: + """Retrieve a group membership. + + Args: + group_id (int): The unique ID of the group. + user_id (int): The unique ID of the user. + error_if_not_found (bool): Flag to raise an error if the membership is + not found. + + Returns: + GroupMembership | None: The group membership if found, else None. + + Raises: + GroupMembershipNotFoundError: If the membership is not found and + error_if_not_found is True. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 - return GroupMembership.query.filter( # type: ignore + membership = GroupMembership.query.filter( GroupMembership.user_id == user_id, GroupMembership.group_id == group_id ).first() - def submit( + if error_if_not_found: + if membership is None: + log.error("Group Membership not found", group_id=group_id) + raise GroupMembershipDoesNotExistError + + return cast(GroupMembership, membership) + + def create( self, group_id: int, user_id: int, @@ -76,10 +84,35 @@ def submit( share_write: bool, **kwargs, ) -> GroupMembership: + """Create a new group membership. + + Args: + group_id (int): The unique ID of the group. + user_id (int): The unique ID of the user. + read (bool): Permission flag for read access. + write (bool): Permission flag for write access. + share_read (bool): Permission flag for sharing with read access. + share_write (bool): Permission flag for sharing with write access. + + Returns: + GroupMembership: The newly created group membership. + + Raises: + GroupMembershipAlreadyExistsError: If the membership already exists. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - new_group_membership: GroupMembership = self.create( - group_id, user_id, read, write, share_read, share_write, log=log + if self.get(group_id=group_id, user_id=user_id) is not None: + log.error("Group Membership already exists", group_id=group_id) + raise GroupMembershipDoesNotExistError + + new_group_membership = GroupMembership( + group_id=group_id, + user_id=user_id, + read=read, + write=write, + share_read=share_read, + share_write=share_write, ) db.session.add(new_group_membership) @@ -93,14 +126,30 @@ def submit( return new_group_membership - def delete(self, group_id, user_id, **kwargs) -> bool: - membership = self.get_by_id(group_id=group_id, user_id=user_id) + def delete(self, group_id, user_id, **kwargs) -> dict[str, Any]: + """Delete a group membership. + + Args: + group_id: The unique ID of the group. + user_id: The unique ID of the user. + + Returns: + A dictionary with the status and IDs of the deleted membership. + Raises: + IntegrityError: If there is an issue with the database integrity during + deletion. + """ + + log: BoundLogger = kwargs.get("log", LOGGER.new()) + if (membership := self.get(group_id=group_id, user_id=user_id)) is None: + return {"status": "Success", "id": []} try: db.session.delete(membership) db.session.commit() - return True + log.info("Group Membership deleted", group_id=group_id, user_id=user_id) + return {"status": "Success", "id": [group_id, user_id]} except IntegrityError: db.session.rollback() - return False + return {"status": "Failure", "id": [group_id, user_id]} diff --git a/tests/unit/restapi/group/test_group.py b/tests/unit/restapi/group/test_group.py index b9a21db67..7db91933d 100644 --- a/tests/unit/restapi/group/test_group.py +++ b/tests/unit/restapi/group/test_group.py @@ -90,7 +90,7 @@ def create_group(group_service: GroupService, name: str = "test") -> Group: Returns: The response from the group service representing the newly created group. """ - return group_service.submit(name) + return group_service.create(name) def get_group(id: int, group_service: GroupService) -> Group | None: @@ -106,7 +106,7 @@ def get_group(id: int, group_service: GroupService) -> Group | None: Raises: GroupDoesNotExistError: If no group with the specified ID is found. """ - return group_service.get_by_id(id) + return group_service.get(id) def delete_group(id: int, group_service: GroupService) -> dict[str, Any]: @@ -151,7 +151,7 @@ def create_group_membership( Raises: GroupMembershipSubmissionError: If there is an issue with the submission. """ - return group_membership_service.submit( + return group_membership_service.create( group_id, user_id, read=read, @@ -177,7 +177,7 @@ def get_group_membership( Raises: GroupMembershipDoesNotExistError: If no group membership with the specified user and group IDs is found. """ - return group_membership_service.get_by_id(group_id, user_id) + return group_membership_service.get(group_id, user_id) def delete_group_membership( From 48f60fc96da854e031063aee3cce445c163e81d0 Mon Sep 17 00:00:00 2001 From: abyrne Date: Mon, 29 Jan 2024 10:46:05 -0500 Subject: [PATCH 4/6] chore: address comments and cleanup --- src/dioptra/restapi/group/service.py | 4 ++-- .../restapi/group_membership/service.py | 22 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/dioptra/restapi/group/service.py b/src/dioptra/restapi/group/service.py index 347b433bd..626244399 100644 --- a/src/dioptra/restapi/group/service.py +++ b/src/dioptra/restapi/group/service.py @@ -18,7 +18,7 @@ from __future__ import annotations import datetime -from typing import Any, List, cast +from typing import Any, cast import structlog from sqlalchemy.exc import IntegrityError @@ -36,7 +36,7 @@ class GroupService(object): """The service methods for registering and managing groups by their unique id.""" - def get_all(self, **kwargs) -> List[Group]: + def get_all(self, **kwargs) -> list[Group]: """Fetch the list of all groups. Returns: diff --git a/src/dioptra/restapi/group_membership/service.py b/src/dioptra/restapi/group_membership/service.py index 682c99150..b984c46e3 100644 --- a/src/dioptra/restapi/group_membership/service.py +++ b/src/dioptra/restapi/group_membership/service.py @@ -36,7 +36,7 @@ def get_all(self, **kwargs) -> list[GroupMembership]: """Retrieve a list of all group memberships. Returns: - List[GroupMembership]: List of group memberships. + List of group memberships. """ log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 @@ -48,13 +48,13 @@ def get( """Retrieve a group membership. Args: - group_id (int): The unique ID of the group. - user_id (int): The unique ID of the user. - error_if_not_found (bool): Flag to raise an error if the membership is + group_id: The unique ID of the group. + user_id: The unique ID of the user. + error_if_not_found: Flag to raise an error if the membership is not found. Returns: - GroupMembership | None: The group membership if found, else None. + The group membership if found, else None. Raises: GroupMembershipNotFoundError: If the membership is not found and @@ -87,12 +87,12 @@ def create( """Create a new group membership. Args: - group_id (int): The unique ID of the group. - user_id (int): The unique ID of the user. - read (bool): Permission flag for read access. - write (bool): Permission flag for write access. - share_read (bool): Permission flag for sharing with read access. - share_write (bool): Permission flag for sharing with write access. + group_id: The unique ID of the group. + user_id: The unique ID of the user. + read: Permission flag for read access. + write: Permission flag for write access. + share_read: Permission flag for sharing with read access. + share_write: Permission flag for sharing with write access. Returns: GroupMembership: The newly created group membership. From 8d91e82fd7a147aa560921fe955e811a2dc13763 Mon Sep 17 00:00:00 2001 From: abyrne Date: Wed, 31 Jan 2024 16:41:59 -0500 Subject: [PATCH 5/6] chore: Restyled files with updated black formatter --- src/dioptra/pyplugs/_plugins.py | 3 +- src/dioptra/restapi/experiment/service.py | 6 +-- src/dioptra/restapi/task_plugin/controller.py | 28 +++++++------- src/dioptra/restapi/utils.py | 3 +- .../tensorflow_backend/confluence.py | 38 +++++++++---------- .../data/tensorflow_backend.py | 16 ++++---- src/dioptra/task_engine/task_engine.py | 6 +-- .../dioptra_builtins/metrics/distance.py | 12 +++--- .../dioptra_builtins/metrics/performance.py | 6 +-- .../mlflow_tracking/test_entrypoint.py | 4 +- tests/unit/restapi/conftest.py | 15 +++++--- tests/unit/restapi/task_plugin/test_schema.py | 6 +-- tests/unit/restapi/test_user.py | 32 ++++++---------- 13 files changed, 85 insertions(+), 90 deletions(-) diff --git a/src/dioptra/pyplugs/_plugins.py b/src/dioptra/pyplugs/_plugins.py index dae1ee393..df9460f65 100644 --- a/src/dioptra/pyplugs/_plugins.py +++ b/src/dioptra/pyplugs/_plugins.py @@ -104,8 +104,7 @@ class NoutPlugin(Protocol): _task_nout: int - def __call__(self, *args, **kwargs) -> Any: - ... # pragma: nocover + def __call__(self, *args, **kwargs) -> Any: ... # pragma: nocover # Type aliases diff --git a/src/dioptra/restapi/experiment/service.py b/src/dioptra/restapi/experiment/service.py index 7cca47c92..b9b001226 100644 --- a/src/dioptra/restapi/experiment/service.py +++ b/src/dioptra/restapi/experiment/service.py @@ -86,9 +86,9 @@ def create( def create_mlflow_experiment(self, experiment_name: str) -> int: try: - experiment_id: Optional[ - str - ] = self._mlflow_tracking_service.create_experiment(experiment_name) + experiment_id: Optional[str] = ( + self._mlflow_tracking_service.create_experiment(experiment_name) + ) except RestException as exc: raise ExperimentMLFlowTrackingRegistrationError from exc diff --git a/src/dioptra/restapi/task_plugin/controller.py b/src/dioptra/restapi/task_plugin/controller.py index 7e2b5003d..ebbaa529a 100644 --- a/src/dioptra/restapi/task_plugin/controller.py +++ b/src/dioptra/restapi/task_plugin/controller.py @@ -148,13 +148,13 @@ def get(self, taskPluginName: str) -> TaskPlugin: ) log.info("Request received") - task_plugin: Optional[ - TaskPlugin - ] = self._task_plugin_service.get_by_name_in_collection( - collection="dioptra_builtins", - task_plugin_name=taskPluginName, - bucket=current_app.config["DIOPTRA_PLUGINS_BUCKET"], - log=log, + task_plugin: Optional[TaskPlugin] = ( + self._task_plugin_service.get_by_name_in_collection( + collection="dioptra_builtins", + task_plugin_name=taskPluginName, + bucket=current_app.config["DIOPTRA_PLUGINS_BUCKET"], + log=log, + ) ) if task_plugin is None: @@ -219,13 +219,13 @@ def get(self, taskPluginName: str) -> TaskPlugin: ) log.info("Request received") - task_plugin: Optional[ - TaskPlugin - ] = self._task_plugin_service.get_by_name_in_collection( - collection="dioptra_custom", - task_plugin_name=taskPluginName, - bucket=current_app.config["DIOPTRA_PLUGINS_BUCKET"], - log=log, + task_plugin: Optional[TaskPlugin] = ( + self._task_plugin_service.get_by_name_in_collection( + collection="dioptra_custom", + task_plugin_name=taskPluginName, + bucket=current_app.config["DIOPTRA_PLUGINS_BUCKET"], + log=log, + ) ) if task_plugin is None: diff --git a/src/dioptra/restapi/utils.py b/src/dioptra/restapi/utils.py index 52026cb8c..e3a5d52b8 100644 --- a/src/dioptra/restapi/utils.py +++ b/src/dioptra/restapi/utils.py @@ -94,8 +94,7 @@ class _ClassBasedViewFunction(Protocol): view_class: Type[View] - def __call__(self, *args, **kwargs) -> Any: - ... + def __call__(self, *args, **kwargs) -> Any: ... def _new_class_view_function( diff --git a/src/dioptra/sdk/object_detection/bounding_boxes/postprocessing/tensorflow_backend/confluence.py b/src/dioptra/sdk/object_detection/bounding_boxes/postprocessing/tensorflow_backend/confluence.py index 36b587bc3..654b8ef6f 100644 --- a/src/dioptra/sdk/object_detection/bounding_boxes/postprocessing/tensorflow_backend/confluence.py +++ b/src/dioptra/sdk/object_detection/bounding_boxes/postprocessing/tensorflow_backend/confluence.py @@ -259,9 +259,9 @@ def confluence( all_proximities = np.ones_like(proximity) cconf_scores = np.zeros_like(cconf) - all_proximities[ - proximity <= self._confluence_threshold - ] = proximity[proximity <= self._confluence_threshold] + all_proximities[proximity <= self._confluence_threshold] = ( + proximity[proximity <= self._confluence_threshold] + ) cconf_scores[proximity <= self._confluence_threshold] = cconf[ proximity <= self._confluence_threshold ] @@ -334,18 +334,18 @@ def confluence( -((1 - manhattan_distance) * (1 - manhattan_distance)) / self._sigma ) - weights[ - manhattan_distance <= self._confluence_threshold - ] = gaussian_weights[ - manhattan_distance <= self._confluence_threshold - ] + weights[manhattan_distance <= self._confluence_threshold] = ( + gaussian_weights[ + manhattan_distance <= self._confluence_threshold + ] + ) else: - weights[ - manhattan_distance <= self._confluence_threshold - ] = manhattan_distance[ - manhattan_distance <= self._confluence_threshold - ] + weights[manhattan_distance <= self._confluence_threshold] = ( + manhattan_distance[ + manhattan_distance <= self._confluence_threshold + ] + ) dets[1:, 4] *= weights to_reprocess = np.where(dets[1:, 4] >= self._score_threshold)[0] @@ -506,14 +506,14 @@ def pad_retained_arrays( ) if len(batch_scores) > 0: - padded_scores[ - batch_idx, 0 : batch_scores.shape[0] - ] = batch_scores.astype("float32") + padded_scores[batch_idx, 0 : batch_scores.shape[0]] = ( + batch_scores.astype("float32") + ) if len(batch_labels) > 0: - padded_labels[ - batch_idx, 0 : batch_labels.shape[0] - ] = batch_labels.astype("int32") + padded_labels[batch_idx, 0 : batch_labels.shape[0]] = ( + batch_labels.astype("int32") + ) padded_detections[batch_idx] = len(batch_scores) diff --git a/src/dioptra/sdk/object_detection/data/tensorflow_backend.py b/src/dioptra/sdk/object_detection/data/tensorflow_backend.py index d24339355..16ba2cc78 100644 --- a/src/dioptra/sdk/object_detection/data/tensorflow_backend.py +++ b/src/dioptra/sdk/object_detection/data/tensorflow_backend.py @@ -115,15 +115,15 @@ def create( augmentations_seed: Optional[int] = None, shuffle_seed: Optional[int] = None, ) -> TensorflowObjectDetectionData: - annotation_data_registry: dict[ - str, Callable[[], PascalVOCAnnotationData] - ] = dict( - pascal_voc=lambda: PascalVOCAnnotationData( - labels=labels, - encoding=NumpyAnnotationEncoding( - boxes_dtype="float32", labels_dtype="int32" + annotation_data_registry: dict[str, Callable[[], PascalVOCAnnotationData]] = ( + dict( + pascal_voc=lambda: PascalVOCAnnotationData( + labels=labels, + encoding=NumpyAnnotationEncoding( + boxes_dtype="float32", labels_dtype="int32" + ), ), - ), + ) ) augmentations_registry: dict[ str, Callable[[], ImgAugObjectDetectionAugmentations] diff --git a/src/dioptra/task_engine/task_engine.py b/src/dioptra/task_engine/task_engine.py index a8e0b5db1..7a59092cb 100644 --- a/src/dioptra/task_engine/task_engine.py +++ b/src/dioptra/task_engine/task_engine.py @@ -468,9 +468,9 @@ def _run_experiment( ) log.debug("Global parameters:\n %s", props_values) - step_outputs: MutableMapping[ - str, MutableMapping[str, Any] - ] = collections.defaultdict(dict) + step_outputs: MutableMapping[str, MutableMapping[str, Any]] = ( + collections.defaultdict(dict) + ) step_order = util.get_sorted_steps(graph) diff --git a/task-plugins/dioptra_builtins/metrics/distance.py b/task-plugins/dioptra_builtins/metrics/distance.py index 1c74f8692..a737ba2df 100644 --- a/task-plugins/dioptra_builtins/metrics/distance.py +++ b/task-plugins/dioptra_builtins/metrics/distance.py @@ -68,9 +68,9 @@ def get_distance_metric_list( distance_metrics_list: List[Tuple[str, Callable[..., np.ndarray]]] = [] for metric in request: - metric_callable: Optional[ - Callable[..., np.ndarray] - ] = DISTANCE_METRICS_REGISTRY.get(metric["func"]) + metric_callable: Optional[Callable[..., np.ndarray]] = ( + DISTANCE_METRICS_REGISTRY.get(metric["func"]) + ) if metric_callable is not None: distance_metrics_list.append((metric["name"], metric_callable)) @@ -106,9 +106,9 @@ def get_distance_metric(func: str) -> Callable[..., np.ndarray]: Returns: A callable distance metric function. """ - metric_callable: Optional[ - Callable[..., np.ndarray] - ] = DISTANCE_METRICS_REGISTRY.get(func) + metric_callable: Optional[Callable[..., np.ndarray]] = ( + DISTANCE_METRICS_REGISTRY.get(func) + ) if metric_callable is None: LOGGER.error( diff --git a/task-plugins/dioptra_builtins/metrics/performance.py b/task-plugins/dioptra_builtins/metrics/performance.py index d66d36736..29d6147b4 100644 --- a/task-plugins/dioptra_builtins/metrics/performance.py +++ b/task-plugins/dioptra_builtins/metrics/performance.py @@ -69,9 +69,9 @@ def get_performance_metric_list( performance_metrics_list: List[Tuple[str, Callable[..., float]]] = [] for metric in request: - metric_callable: Optional[ - Callable[..., float] - ] = PERFORMANCE_METRICS_REGISTRY.get(metric["func"]) + metric_callable: Optional[Callable[..., float]] = ( + PERFORMANCE_METRICS_REGISTRY.get(metric["func"]) + ) if metric_callable is not None: performance_metrics_list.append((metric["name"], metric_callable)) diff --git a/tests/containers/mlflow_tracking/test_entrypoint.py b/tests/containers/mlflow_tracking/test_entrypoint.py index e7adb986f..9ef038cf7 100644 --- a/tests/containers/mlflow_tracking/test_entrypoint.py +++ b/tests/containers/mlflow_tracking/test_entrypoint.py @@ -54,7 +54,9 @@ def host(container: Container) -> Host: @pytest.fixture(scope="function") def print_db_tables_pyscript(container: Container, tmp_path: Path) -> str: - pyscript: str | bytes = """ + pyscript: ( + str | bytes + ) = """ import sqlite3\n con = sqlite3.connect("/work/mlruns/mlflow-tracking.db") diff --git a/tests/unit/restapi/conftest.py b/tests/unit/restapi/conftest.py index dc8edb11c..41427fa69 100644 --- a/tests/unit/restapi/conftest.py +++ b/tests/unit/restapi/conftest.py @@ -58,9 +58,10 @@ def task_plugins_dir(tmp_path_factory): def workflow_tar_gz(): workflow_tar_gz_fileobj: BinaryIO = io.BytesIO() - with tarfile.open(fileobj=workflow_tar_gz_fileobj, mode="w:gz") as f, io.BytesIO( - initial_bytes=b"data" - ) as data: + with ( + tarfile.open(fileobj=workflow_tar_gz_fileobj, mode="w:gz") as f, + io.BytesIO(initial_bytes=b"data") as data, + ): tarinfo = tarfile.TarInfo(name="MLproject") tarinfo.size = len(data.getbuffer()) f.addfile(tarinfo=tarinfo, fileobj=data) @@ -74,9 +75,11 @@ def workflow_tar_gz(): def task_plugin_archive(): archive_fileobj: BinaryIO = io.BytesIO() - with tarfile.open(fileobj=archive_fileobj, mode="w:gz") as f, io.BytesIO( - initial_bytes=b"# init file" - ) as f_init, io.BytesIO(initial_bytes=b"# plugin module") as f_plugin_module: + with ( + tarfile.open(fileobj=archive_fileobj, mode="w:gz") as f, + io.BytesIO(initial_bytes=b"# init file") as f_init, + io.BytesIO(initial_bytes=b"# plugin module") as f_plugin_module, + ): tarinfo_init = tarfile.TarInfo(name="new_plugin_module/__init__.py") tarinfo_init.size = len(f_init.getbuffer()) f.addfile(tarinfo=tarinfo_init, fileobj=f_init) diff --git a/tests/unit/restapi/task_plugin/test_schema.py b/tests/unit/restapi/task_plugin/test_schema.py index 0758d014a..0bbc03968 100644 --- a/tests/unit/restapi/task_plugin/test_schema.py +++ b/tests/unit/restapi/task_plugin/test_schema.py @@ -111,9 +111,9 @@ def test_TaskPluginUploadFormSchema_dump_works( task_plugin_upload_form_schema: TaskPluginUploadFormSchema, task_plugin_archive: BinaryIO, ) -> None: - task_plugin_upload_form_serialized: Dict[ - str, Any - ] = task_plugin_upload_form_schema.dump(task_plugin_upload_form) + task_plugin_upload_form_serialized: Dict[str, Any] = ( + task_plugin_upload_form_schema.dump(task_plugin_upload_form) + ) assert task_plugin_upload_form_serialized["task_plugin_name"] == "new_plugin_one" assert task_plugin_upload_form_serialized["collection"] == "dioptra_custom" diff --git a/tests/unit/restapi/test_user.py b/tests/unit/restapi/test_user.py index a6cc95311..647ddb463 100644 --- a/tests/unit/restapi/test_user.py +++ b/tests/unit/restapi/test_user.py @@ -200,13 +200,13 @@ def _(client: RequestsSession) -> dict[str, Any]: @overload -def login(client: FlaskClient, username: str, password: str) -> TestResponse: - ... +def login(client: FlaskClient, username: str, password: str) -> TestResponse: ... @overload -def login(client: RequestsSession, username: str, password: str) -> RequestsResponse: - ... +def login( + client: RequestsSession, username: str, password: str +) -> RequestsResponse: ... @singledispatch @@ -244,13 +244,11 @@ def _(client: RequestsSession, username: str, password: str) -> RequestsResponse @overload -def logout(client: FlaskClient, everywhere: bool) -> TestResponse: - ... +def logout(client: FlaskClient, everywhere: bool) -> TestResponse: ... @overload -def logout(client: RequestsSession, everywhere: bool) -> RequestsResponse: - ... +def logout(client: RequestsSession, everywhere: bool) -> RequestsResponse: ... @singledispatch @@ -288,15 +286,13 @@ def _(client: RequestsSession, everywhere: bool) -> RequestsResponse: @overload def change_password( client: FlaskClient, user_id: int, current_password: str, new_password: str -) -> TestResponse: - ... +) -> TestResponse: ... @overload def change_password( client: RequestsSession, user_id: int, current_password: str, new_password: str -) -> RequestsResponse: - ... +) -> RequestsResponse: ... @singledispatch @@ -351,15 +347,13 @@ def _( @overload def change_current_user_password( client: FlaskClient, current_password: str, new_password: str -) -> TestResponse: - ... +) -> TestResponse: ... @overload def change_current_user_password( client: RequestsSession, current_password: str, new_password: str -) -> RequestsResponse: - ... +) -> RequestsResponse: ... @singledispatch @@ -407,13 +401,11 @@ def _( @overload -def delete_current_user(client: FlaskClient, password: str) -> TestResponse: - ... +def delete_current_user(client: FlaskClient, password: str) -> TestResponse: ... @overload -def delete_current_user(client: RequestsSession, password: str) -> RequestsResponse: - ... +def delete_current_user(client: RequestsSession, password: str) -> RequestsResponse: ... @singledispatch From 0fb8175cf58c044f586dcd27569d2c18680b367c Mon Sep 17 00:00:00 2001 From: abyrne Date: Wed, 31 Jan 2024 16:50:14 -0500 Subject: [PATCH 6/6] chore: fix black caused mypy errors --- src/dioptra/pyplugs/_plugins.py | 2 +- src/dioptra/restapi/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dioptra/pyplugs/_plugins.py b/src/dioptra/pyplugs/_plugins.py index df9460f65..95102de2a 100644 --- a/src/dioptra/pyplugs/_plugins.py +++ b/src/dioptra/pyplugs/_plugins.py @@ -104,7 +104,7 @@ class NoutPlugin(Protocol): _task_nout: int - def __call__(self, *args, **kwargs) -> Any: ... # pragma: nocover + def __call__(self, *args, **kwargs) -> Any: ... # pragma: nocover # noqa E704 # Type aliases diff --git a/src/dioptra/restapi/utils.py b/src/dioptra/restapi/utils.py index e3a5d52b8..faa68b6e6 100644 --- a/src/dioptra/restapi/utils.py +++ b/src/dioptra/restapi/utils.py @@ -94,7 +94,7 @@ class _ClassBasedViewFunction(Protocol): view_class: Type[View] - def __call__(self, *args, **kwargs) -> Any: ... + def __call__(self, *args, **kwargs) -> Any: ... # noqa E704 def _new_class_view_function(