diff --git a/src/dioptra/pyplugs/_plugins.py b/src/dioptra/pyplugs/_plugins.py index dae1ee393..95102de2a 100644 --- a/src/dioptra/pyplugs/_plugins.py +++ b/src/dioptra/pyplugs/_plugins.py @@ -104,8 +104,7 @@ class NoutPlugin(Protocol): _task_nout: int - def __call__(self, *args, **kwargs) -> Any: - ... # pragma: nocover + def __call__(self, *args, **kwargs) -> Any: ... # pragma: nocover # noqa E704 # Type aliases diff --git a/src/dioptra/restapi/dependencies.py b/src/dioptra/restapi/dependencies.py index c61bc8f0e..a1912355c 100644 --- a/src/dioptra/restapi/dependencies.py +++ b/src/dioptra/restapi/dependencies.py @@ -29,6 +29,10 @@ def bind_dependencies(binder: Binder) -> None: binder: A :py:class:`~injector.Binder` object. """ from .experiment import bind_dependencies as attach_experiment_dependencies + from .group import bind_dependencies as attach_group_dependencies + from .group_membership import ( + bind_dependencies as attach_group_membership_dependencies, + ) from .job import bind_dependencies as attach_job_dependencies from .queue import bind_dependencies as attach_job_queue_dependencies from .task_plugin import bind_dependencies as attach_task_plugin_dependencies @@ -40,6 +44,8 @@ def bind_dependencies(binder: Binder) -> None: attach_job_queue_dependencies(binder) attach_task_plugin_dependencies(binder) attach_user_dependencies(binder) + attach_group_dependencies(binder) + attach_group_membership_dependencies(binder) def register_providers(modules: List[Callable[..., Any]]) -> None: @@ -50,6 +56,10 @@ def register_providers(modules: List[Callable[..., Any]]) -> None: environment. """ from .experiment import register_providers as attach_experiment_providers + from .group import register_providers as attach_group_providers + from .group_membership import ( + register_providers as attach_group_membership_providers, + ) from .job import register_providers as attach_job_providers from .queue import register_providers as attach_job_queue_providers from .task_plugin import register_providers as attach_task_plugin_providers @@ -61,3 +71,5 @@ def register_providers(modules: List[Callable[..., Any]]) -> None: attach_job_queue_providers(modules) attach_task_plugin_providers(modules) attach_user_providers(modules) + attach_group_providers(modules) + attach_group_membership_providers(modules) diff --git a/src/dioptra/restapi/experiment/service.py b/src/dioptra/restapi/experiment/service.py index 7cca47c92..b9b001226 100644 --- a/src/dioptra/restapi/experiment/service.py +++ b/src/dioptra/restapi/experiment/service.py @@ -86,9 +86,9 @@ def create( def create_mlflow_experiment(self, experiment_name: str) -> int: try: - experiment_id: Optional[ - str - ] = self._mlflow_tracking_service.create_experiment(experiment_name) + experiment_id: Optional[str] = ( + self._mlflow_tracking_service.create_experiment(experiment_name) + ) except RestException as exc: raise ExperimentMLFlowTrackingRegistrationError from exc diff --git a/src/dioptra/restapi/group/__init__.py b/src/dioptra/restapi/group/__init__.py new file mode 100644 index 000000000..543416304 --- /dev/null +++ b/src/dioptra/restapi/group/__init__.py @@ -0,0 +1,28 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The group endpoint subpackage.""" + +from .dependencies import bind_dependencies, register_providers +from .errors import register_error_handlers +from .routes import register_routes + +__all__ = [ + "bind_dependencies", + "register_error_handlers", + "register_providers", + "register_routes", +] diff --git a/src/dioptra/restapi/group/controller.py b/src/dioptra/restapi/group/controller.py new file mode 100644 index 000000000..2f976bd07 --- /dev/null +++ b/src/dioptra/restapi/group/controller.py @@ -0,0 +1,112 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The module defining the group endpoints.""" +from __future__ import annotations + +import uuid +from typing import Any, cast + +import structlog +from flask import request +from flask_accepts import accepts, responds +from flask_restx import Namespace, Resource +from injector import inject +from structlog.stdlib import BoundLogger + +from .model import Group +from .schema import GroupSchema, IdStatusResponseSchema +from .service import GroupService + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +api: Namespace = Namespace( + "Group", + description="Group submission and management operations", +) + + +@api.route("/") +class GroupResource(Resource): + """Shows a list of all Group, and lets you POST to create new groups.""" + + @inject + def __init__( + self, + *args, + group_service: GroupService, + **kwargs, + ) -> None: + self._group_service = group_service + super().__init__(*args, **kwargs) + + @responds(schema=GroupSchema(many=True), api=api) + def get(self) -> list[Group]: + """Gets a list of all groups.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="group", request_type="GET" + ) # noqa: F841 + log.info("Request received") + return self._group_service.get_all(log=log) + + @accepts(schema=GroupSchema, api=api) + @responds(schema=GroupSchema, api=api) + def post(self) -> Group: + """Creates a new Group via a group submission form with an attached file.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="group", request_type="POST" + ) # noqa: F841 + + log.info("Request received") + + parsed_obj = request.parsed_obj # type: ignore + name = str(parsed_obj["group_name"]) + return self._group_service.create(name=name, log=log) + + @accepts(schema=GroupSchema, api=api) + @responds(schema=IdStatusResponseSchema, api=api) + def delete(self) -> dict[str, Any]: + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="group", request_type="POST" + ) # noqa: F841 + + log.info("Request received") + + parsed_obj = request.parsed_obj # type: ignore + group_id = int(parsed_obj["id"]) + return self._group_service.delete(id=group_id) + + +@api.route("/") +@api.param("groupId", "A string specifying a group's UUID.") +class GroupIdResource(Resource): + """Shows a single job.""" + + @inject + def __init__(self, *args, _service: GroupService, **kwargs) -> None: + self._group_service = _service + super().__init__(*args, **kwargs) + + @responds(schema=GroupSchema, api=api) + def get(self, groupId: int) -> Group: + """Gets a group by its unique identifier.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="groupId", request_type="GET" + ) # noqa: F841 + log.info("Request received", group_id=groupId) + group = self._group_service.get(groupId, error_if_not_found=True, log=log) + + return cast(Group, group) diff --git a/src/dioptra/restapi/group/dependencies.py b/src/dioptra/restapi/group/dependencies.py new file mode 100644 index 000000000..a515d71a9 --- /dev/null +++ b/src/dioptra/restapi/group/dependencies.py @@ -0,0 +1,51 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Binding configurations to shared services using dependency injection.""" +from __future__ import annotations + +from typing import Any, Callable + +from injector import Binder, Module, provider + +from .service import GroupService + + +class GroupServiceModule(Module): + @provider + def provide_queue_name_service_module( + self, + ) -> GroupService: + return GroupService() + + +def bind_dependencies(binder: Binder) -> None: + """Binds interfaces to implementations within the main application. + + Args: + binder: A :py:class:`~injector.Binder` object. + """ + pass + + +def register_providers(modules: list[Callable[..., Any]]) -> None: + """Registers type providers within the main application. + + Args: + modules: A list of callables used for configuring the dependency injection + environment. + """ + modules.append(GroupServiceModule) diff --git a/src/dioptra/restapi/group/errors.py b/src/dioptra/restapi/group/errors.py new file mode 100644 index 000000000..32b98c4dc --- /dev/null +++ b/src/dioptra/restapi/group/errors.py @@ -0,0 +1,44 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Error handlers for the group endpoints.""" +from __future__ import annotations + +from flask_restx import Api + + +class GroupDoesNotExistError(Exception): + """The requested group does not exist.""" + + +class GroupSubmissionError(Exception): + """The Group submission form contains invalid parameters.""" + + +def register_error_handlers(api: Api) -> None: + @api.errorhandler(GroupDoesNotExistError) + def handle_job_does_not_exist_error(error): + return {"message": "Not Found - The requested group does not exist"}, 404 + + @api.errorhandler(GroupSubmissionError) + def handle_job_submission_error(error): + return ( + { + "message": "Bad Request - The group submission form contains " + "invalid parameters. Please verify and resubmit." + }, + 400, + ) diff --git a/src/dioptra/restapi/group/model.py b/src/dioptra/restapi/group/model.py new file mode 100644 index 000000000..1bbb2ba10 --- /dev/null +++ b/src/dioptra/restapi/group/model.py @@ -0,0 +1,99 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The data models for the job endpoint objects.""" +from __future__ import annotations + +from dioptra.restapi.app import db +from dioptra.restapi.group_membership.model import GroupMembership +from dioptra.restapi.user.model import User + + +class Group(db.Model): + """The Groups table. + + Attributes: + group_id: The unique identifier of the group. + name: Human-readable name for the group. + creator_id: The id for the user that created the group. + owner_id: The id for the user that owns the group. + created_on: The time at which the group was created. + deleted: Whether the group has been deleted. + """ + + __tablename__ = "groups" + + group_id = db.Column(db.BigInteger(), primary_key=True) + name = db.Column(db.String(36)) + + creator_id = db.Column(db.BigInteger(), db.ForeignKey("users.user_id"), index=True) + owner_id = db.Column(db.BigInteger(), db.ForeignKey("users.user_id"), index=True) + + created_on = db.Column(db.DateTime()) + deleted = db.Column(db.Boolean) + + creator = db.relationship("User", foreign_keys=[creator_id]) + owner = db.relationship("User", foreign_keys=[owner_id]) + + @classmethod + def next_id(cls) -> int: + """Generates the next id in the sequence.""" + group: Group | None = cls.query.order_by(cls.group_id.desc()).first() + + if group is None: + return 1 + + return int(group.id) + 1 + + @property + def users(self): + """The users that are members of the group.""" + return ( + User.query.join(GroupMembership) + .filter(GroupMembership.group_id == self.group_id) + .all() + ) + + def check_membership(self, user: User) -> bool: + """Check if the user has permission to perform the specified action. + + Args: + user: The user to check. + action: The action to check. + + Returns: + True if the user has permission to perform the action, False otherwise. + """ + membership = GroupMembership.query.filter_by( + GroupMembership.user_id == user.user_id, + GroupMembership.group_id == self.group_id, + ) + + if membership is None: + return False + else: + return True + + def update(self, changes: dict): + """Updates the record. + + Args: + changes: A dictionary containing record updates. + """ + for key, val in changes.items(): + setattr(self, key, val) + + return self diff --git a/src/dioptra/restapi/group/routes.py b/src/dioptra/restapi/group/routes.py new file mode 100644 index 000000000..2786cd476 --- /dev/null +++ b/src/dioptra/restapi/group/routes.py @@ -0,0 +1,41 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Methods for registering the group endpoint routes with the main application. + +.. |Api| replace:: :py:class:`flask_restx.Api` +.. |Flask| replace:: :py:class:`flask.Flask` +""" +from __future__ import annotations + +from flask import Flask +from flask_restx import Api + +BASE_ROUTE: str = "group" + + +def register_routes(api: Api, app: Flask, root: str = "api") -> None: + """Registers the job endpoint routes with the main application. + + Args: + api: The main REST |Api| object. + app: The main |Flask| application. + root: The root path for the registration prefix of the namespace. The default + is `"api"`. + """ + from .controller import api as endpoint_api + + api.add_namespace(endpoint_api, path=f"/{root}/{BASE_ROUTE}") diff --git a/src/dioptra/restapi/group/schema.py b/src/dioptra/restapi/group/schema.py new file mode 100644 index 000000000..a109a8fa1 --- /dev/null +++ b/src/dioptra/restapi/group/schema.py @@ -0,0 +1,76 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The schemas for serializing/deserializing the job endpoint objects. + +.. |Job| replace:: :py:class:`~.model.Job` +.. |JobForm| replace:: :py:class:`~.model.JobForm` +.. |JobFormData| replace:: :py:class:`~.model.JobFormData` +""" +from __future__ import annotations + +from marshmallow import Schema, fields + + +class GroupSchema(Schema): + """The schema for the data stored in a |Group| object.""" + + group_id = fields.Integer( + attribute="id", metadata=dict(description="A UUID that identifies the group.") + ) + name = fields.String( + attribute="name", + allow_none=True, # should we force the user to pick a name? + metadata=dict( + description="Human-readable name for the group.", + ), + ) + creator_id = fields.Integer( + attribute="creator_id", + metadata=dict( + description="An integer identifying" "the user that created the group." + ), + ) + owner_id = fields.Integer( + attribute="owner_id", + metadata=dict( + description="An integer identifying the user that owns" "the group." + ), + ) + createdOn = fields.DateTime( + attribute="created_on", + metadata=dict(description="The date and time the group was created."), + ) + deleted = fields.Boolean( + attribute="deleted", + metadata=dict(description="Whether the group has been deleted."), + ) + + +class IdStatusResponseSchema(Schema): + """A simple response for reporting a status for one or more objects.""" + + status = fields.String( + attribute="status", + metadata=dict(description="The status of the request."), + ) + id = fields.List( + fields.Integer(), + attribute="id", + metadata=dict( + description="A list of integers identifying the affected object(s)." + ), + ) diff --git a/src/dioptra/restapi/group/service.py b/src/dioptra/restapi/group/service.py new file mode 100644 index 000000000..626244399 --- /dev/null +++ b/src/dioptra/restapi/group/service.py @@ -0,0 +1,133 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The server-side functions that perform job endpoint operations.""" +from __future__ import annotations + +import datetime +from typing import Any, cast + +import structlog +from sqlalchemy.exc import IntegrityError +from structlog.stdlib import BoundLogger + +from dioptra.restapi.app import db +from dioptra.restapi.utils import slugify + +from .errors import GroupDoesNotExistError +from .model import Group + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + + +class GroupService(object): + """The service methods for registering and managing groups by their unique id.""" + + def get_all(self, **kwargs) -> list[Group]: + """Fetch the list of all groups. + + Returns: + A list of groups. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + + log.info("Get full list of groups.") + + return Group.query.all() # type: ignore + + def get( + self, group_id: int, error_if_not_found: bool = False, **kwargs + ) -> Group | None: + """Fetch a group by its unique id. + + Args: + group_id: The unique id of the group. + error_if_not_found: Raise an error if the group cannot be found. + Returns: + The group object if found, otherwise None. + + Raises: + GroupDoesNotExistError: If the group with group_id cannot be found. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + + log.info("Get group by id", group_id=group_id) + group = Group.query.filter_by(group_id=group_id, deleted=False).first() + + if group is None: + if error_if_not_found: + log.error("Group not found", group_id=group_id) + raise GroupDoesNotExistError + return None + + return cast(Group, group) + + def create(self, name: str, user_id=None, **kwargs) -> Group: + """Create a new group. + + Args: + name: The name of the group. + user_id: The id of the user creating the group. + + Returns: + The newly created group object. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + timestamp = datetime.datetime.now() + + # #to be used when user is fully implemented + # if user_id is None: + # user_id= current_user.id + + new_group = Group( + group_id=Group.next_id(), + name=slugify(name), + creator_id=user_id, + owner_id=user_id, + created_on=timestamp, + deleted=False, + ) + + db.session.add(new_group) + db.session.commit() + + log.info("Group submission successful", group_id=new_group.group_id) + + return new_group + + def delete(self, id: int, **kwargs) -> dict[str, Any]: + """Delete a group. + + Args: + group_id: The unique id of the group. + + Returns: + A dictionary reporting the status of the request. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + if (group := self.get(id, log=log)) is None: + return {"status": "Success", "id": []} + group.update(changes={"deleted": True}) + try: + db.session.commit() + + log.info("Group deleted", group_id=id) + return {"status": "Success", "id": [id]} + except IntegrityError: + db.session.rollback() + return {"status": "Failure", "id": [id]} diff --git a/src/dioptra/restapi/group_membership/__init__.py b/src/dioptra/restapi/group_membership/__init__.py new file mode 100644 index 000000000..da65ce662 --- /dev/null +++ b/src/dioptra/restapi/group_membership/__init__.py @@ -0,0 +1,28 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The group membership endpoint subpackage.""" + +from .dependencies import bind_dependencies, register_providers +from .errors import register_error_handlers +from .routes import register_routes + +__all__ = [ + "bind_dependencies", + "register_error_handlers", + "register_providers", + "register_routes", +] diff --git a/src/dioptra/restapi/group_membership/controller.py b/src/dioptra/restapi/group_membership/controller.py new file mode 100644 index 000000000..fd0a489a0 --- /dev/null +++ b/src/dioptra/restapi/group_membership/controller.py @@ -0,0 +1,106 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The module defining the group membership endpoints.""" +from __future__ import annotations + +import uuid +from typing import Any + +import structlog +from flask import request +from flask_accepts import accepts, responds +from flask_restx import Namespace, Resource +from injector import inject +from structlog.stdlib import BoundLogger + +from .model import GroupMembership +from .schema import GroupMembershipSchema, IdStatusResponseSchema +from .service import GroupMembershipService + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +api: Namespace = Namespace( + "GroupMembership", + description="Add users to groups", +) + + +@api.route("/") +class GroupMembershipResource(Resource): + """Manage group memberships.""" + + @inject + def __init__( + self, + *args, + group_membership_service: GroupMembershipService, + **kwargs, + ) -> None: + self._group_membership_service = group_membership_service + super().__init__(*args, **kwargs) + + @responds(schema=GroupMembershipSchema(many=True), api=api) + def get(self) -> list[GroupMembership]: + """Get a list of all group memberships.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), + resource="group_membership", + request_type="GET", + ) + log.info("Request received") + return self._group_membership_service.get_all(log=log) + + @accepts(schema=GroupMembershipSchema, api=api) + @responds(schema=GroupMembershipSchema, api=api) + def post(self) -> GroupMembership: + """Create a new group membership using a group membership submission form.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), + resource="group_membership", + request_type="POST", + ) + + log.info("Request received") + + parsed_obj = request.parsed_obj # type: ignore + group_id = int(parsed_obj["group_id"]) + user_id = int(parsed_obj["user_id"]) + read = bool(parsed_obj["read"]) + write = bool(parsed_obj["write"]) + share_read = bool(parsed_obj["share_read"]) + share_write = bool(parsed_obj["share_write"]) + + return self._group_membership_service.create( + group_id, user_id, read, write, share_read, share_write, log=log + ) + + @accepts(schema=GroupMembershipSchema, api=api) + @responds(schema=IdStatusResponseSchema, api=api) + def delete(self) -> dict[str, Any]: + """Delete a group membership.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), + resource="group_membership", + request_type="DELETE", + ) + + log.info("Request received") + + parsed_obj = request.parsed_obj # type: ignore + group_id = int(parsed_obj["group_id"]) + user_id = int(parsed_obj["user_id"]) + return self._group_membership_service.delete(group_id, user_id, log=log) diff --git a/src/dioptra/restapi/group_membership/dependencies.py b/src/dioptra/restapi/group_membership/dependencies.py new file mode 100644 index 000000000..bf5e5428d --- /dev/null +++ b/src/dioptra/restapi/group_membership/dependencies.py @@ -0,0 +1,51 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Binding configurations to shared services using dependency injection.""" +from __future__ import annotations + +from typing import Any, Callable + +from injector import Binder, Module, provider + +from .service import GroupMembershipService + + +class GroupMembershipServiceModule(Module): + @provider + def provide_queue_name_service_module( + self, + ) -> GroupMembershipService: + return GroupMembershipService() + + +def bind_dependencies(binder: Binder) -> None: + """Binds interfaces to implementations within the main application. + + Args: + binder: A :py:class:`~injector.Binder` object. + """ + pass + + +def register_providers(modules: list[Callable[..., Any]]) -> None: + """Registers type providers within the main application. + + Args: + modules: A list of callables used for configuring the dependency injection + environment. + """ + modules.append(GroupMembershipServiceModule) diff --git a/src/dioptra/restapi/group_membership/errors.py b/src/dioptra/restapi/group_membership/errors.py new file mode 100644 index 000000000..626c81d34 --- /dev/null +++ b/src/dioptra/restapi/group_membership/errors.py @@ -0,0 +1,54 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Error handlers for the group membership endpoints.""" +from __future__ import annotations + +from flask_restx import Api + + +class GroupMembershipDoesNotExistError(Exception): + """The requested group membership does not exist.""" + + +class GroupMembershipSubmissionError(Exception): + """The group membership submission form contains invalid parameters.""" + + +class GroupMembershipAlreadyExistsError(Exception): + """The group membership submission form contains invalid parameters.""" + + +def register_error_handlers(api: Api) -> None: + @api.errorhandler(GroupMembershipDoesNotExistError) + def handle_group_membership_does_not_exist_error(error): + return { + "message": "Not Found - The requested group membership does not exist" + }, 404 + + @api.errorhandler(GroupMembershipAlreadyExistsError) + def handle_group_membership_already_exists_error(error): + return {"message": "Bad Request - The group membership already exists"}, 400 + + @api.errorhandler(GroupMembershipSubmissionError) + def handle_group_membership_submission_error(error): + return ( + { + "message": "Bad Request - The group membership submission form contains" + "invalid parameters. Please verify and resubmit." + }, + 400, + ) diff --git a/src/dioptra/restapi/group_membership/model.py b/src/dioptra/restapi/group_membership/model.py new file mode 100644 index 000000000..7c914905e --- /dev/null +++ b/src/dioptra/restapi/group_membership/model.py @@ -0,0 +1,67 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The data models for the job endpoint objects.""" +from __future__ import annotations + +from dioptra.restapi.app import db + + +class GroupMembership(db.Model): + """The group memberships table. + + Attributes: + user_id (int): The ID of the user who is a member of the group. + group_id (str): The ID of the group to which the user belongs. + read (bool): Indicates whether the user has read permissions in the group. + write (bool): Indicates whether the user has write permissions in the group. + share_read (bool): Indicates whether the user can share read permissions with + others in the group. + share_write (bool): Indicates whether the user can share write permissions + with others in the group. + """ + + __tablename__ = "group_memberships" + + user_id = db.Column( + db.BigInteger(), db.ForeignKey("users.user_id"), primary_key=True + ) + group_id = db.Column( + db.BigInteger(), db.ForeignKey("groups.group_id"), primary_key=True + ) + + read = db.Column(db.Boolean, default=False) + write = db.Column(db.Boolean, default=False) + share_read = db.Column(db.Boolean, default=False) + share_write = db.Column(db.Boolean, default=False) + + # is back populates needed? + user = db.relationship( + "User", + foreign_keys=[user_id], + ) + group = db.relationship("Group", foreign_keys=[group_id]) + + def update(self, changes: dict): + """Updates the record. + + Args: + changes: A dictionary containing record updates. + """ + for key, val in changes.items(): + setattr(self, key, val) + + return self diff --git a/src/dioptra/restapi/group_membership/routes.py b/src/dioptra/restapi/group_membership/routes.py new file mode 100644 index 000000000..870ce937d --- /dev/null +++ b/src/dioptra/restapi/group_membership/routes.py @@ -0,0 +1,42 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Methods for registering the group membership endpoint routes with the main +application. + +.. |Api| replace:: :py:class:`flask_restx.Api` +.. |Flask| replace:: :py:class:`flask.Flask` +""" +from __future__ import annotations + +from flask import Flask +from flask_restx import Api + +BASE_ROUTE: str = "group_membership" + + +def register_routes(api: Api, app: Flask, root: str = "api") -> None: + """Registers the job endpoint routes with the main application. + + Args: + api: The main REST |Api| object. + app: The main |Flask| application. + root: The root path for the registration prefix of the namespace. The default + is `"api"`. + """ + from .controller import api as endpoint_api + + api.add_namespace(endpoint_api, path=f"/{root}/{BASE_ROUTE}") diff --git a/src/dioptra/restapi/group_membership/schema.py b/src/dioptra/restapi/group_membership/schema.py new file mode 100644 index 000000000..c7e322d39 --- /dev/null +++ b/src/dioptra/restapi/group_membership/schema.py @@ -0,0 +1,80 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The schemas for serializing/deserializing the job endpoint objects. + +.. |Job| replace:: :py:class:`~.model.Job` +.. |JobForm| replace:: :py:class:`~.model.JobForm` +.. |JobFormData| replace:: :py:class:`~.model.JobFormData` +""" +from __future__ import annotations + +from marshmallow import Schema, fields + + +class GroupMembershipSchema(Schema): + """The schema for the data stored in a |GroupMembership| object.""" + + user_id = fields.Integer( + attribute="user_id", + metadata=dict(description="The ID of the user who is a member of the group."), + ) + group_id = fields.String( + attribute="group_id", + metadata=dict(description="The ID of the group to which the user belongs."), + ) + read = fields.Boolean( + attribute="read", + metadata=dict( + description="Indicates whether the user has read permissions in the group." + ), + ) + write = fields.Boolean( + attribute="write", + metadata=dict( + description="Indicates whether the user has write permissions in the group." + ), + ) + share_read = fields.Boolean( + attribute="share_read", + metadata=dict( + description="Indicates whether the user can share read permissions with \ + others in the group." + ), + ) + share_write = fields.Boolean( + attribute="share_write", + metadata=dict( + description="Indicates whether the user can share write permissions with \ + others in the group." + ), + ) + + +class IdStatusResponseSchema(Schema): + """A simple response for reporting a status for one or more objects.""" + + status = fields.String( + attribute="status", + metadata=dict(description="The status of the request."), + ) + id = fields.List( + fields.Integer(), + attribute="id", + metadata=dict( + description="A list of integers identifying the affected object(s)." + ), + ) diff --git a/src/dioptra/restapi/group_membership/service.py b/src/dioptra/restapi/group_membership/service.py new file mode 100644 index 000000000..b984c46e3 --- /dev/null +++ b/src/dioptra/restapi/group_membership/service.py @@ -0,0 +1,155 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The server-side functions that perform group membership endpoint operations.""" +from __future__ import annotations + +from typing import Any, cast + +import structlog +from sqlalchemy.exc import IntegrityError +from structlog.stdlib import BoundLogger + +from dioptra.restapi.app import db + +from .errors import GroupMembershipDoesNotExistError +from .model import GroupMembership + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + + +class GroupMembershipService(object): + def get_all(self, **kwargs) -> list[GroupMembership]: + """Retrieve a list of all group memberships. + + Returns: + List of group memberships. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + + return GroupMembership.query.all() # type: ignore + + def get( + self, group_id: int, user_id: int, error_if_not_found: bool = False, **kwargs + ) -> GroupMembership | None: + """Retrieve a group membership. + + Args: + group_id: The unique ID of the group. + user_id: The unique ID of the user. + error_if_not_found: Flag to raise an error if the membership is + not found. + + Returns: + The group membership if found, else None. + + Raises: + GroupMembershipNotFoundError: If the membership is not found and + error_if_not_found is True. + """ + + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + + membership = GroupMembership.query.filter( + GroupMembership.user_id == user_id, GroupMembership.group_id == group_id + ).first() + + if error_if_not_found: + if membership is None: + log.error("Group Membership not found", group_id=group_id) + raise GroupMembershipDoesNotExistError + + return cast(GroupMembership, membership) + + def create( + self, + group_id: int, + user_id: int, + read: bool, + write: bool, + share_read: bool, + share_write: bool, + **kwargs, + ) -> GroupMembership: + """Create a new group membership. + + Args: + group_id: The unique ID of the group. + user_id: The unique ID of the user. + read: Permission flag for read access. + write: Permission flag for write access. + share_read: Permission flag for sharing with read access. + share_write: Permission flag for sharing with write access. + + Returns: + GroupMembership: The newly created group membership. + + Raises: + GroupMembershipAlreadyExistsError: If the membership already exists. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + if self.get(group_id=group_id, user_id=user_id) is not None: + log.error("Group Membership already exists", group_id=group_id) + raise GroupMembershipDoesNotExistError + + new_group_membership = GroupMembership( + group_id=group_id, + user_id=user_id, + read=read, + write=write, + share_read=share_read, + share_write=share_write, + ) + + db.session.add(new_group_membership) + db.session.commit() + + log.info( + "Group Membership submission successful", + group_id=new_group_membership.group_id, + user_id=new_group_membership.user_id, + ) + + return new_group_membership + + def delete(self, group_id, user_id, **kwargs) -> dict[str, Any]: + """Delete a group membership. + + Args: + group_id: The unique ID of the group. + user_id: The unique ID of the user. + + Returns: + A dictionary with the status and IDs of the deleted membership. + Raises: + IntegrityError: If there is an issue with the database integrity during + deletion. + """ + + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + if (membership := self.get(group_id=group_id, user_id=user_id)) is None: + return {"status": "Success", "id": []} + try: + db.session.delete(membership) + db.session.commit() + + log.info("Group Membership deleted", group_id=group_id, user_id=user_id) + return {"status": "Success", "id": [group_id, user_id]} + except IntegrityError: + db.session.rollback() + return {"status": "Failure", "id": [group_id, user_id]} diff --git a/src/dioptra/restapi/models.py b/src/dioptra/restapi/models.py index 925b9c928..bca58e775 100644 --- a/src/dioptra/restapi/models.py +++ b/src/dioptra/restapi/models.py @@ -22,6 +22,8 @@ ExperimentRegistrationForm, ExperimentRegistrationFormData, ) +from .group.model import Group +from .group_membership.model import GroupMembership from .job.model import Job, JobForm, JobFormData from .queue.model import Queue, QueueLock from .task_plugin.model import ( @@ -44,4 +46,6 @@ "TaskPluginUploadForm", "TaskPluginUploadFormData", "User", + "Group", + "GroupMembership", ] diff --git a/src/dioptra/restapi/routes.py b/src/dioptra/restapi/routes.py index 1718ad720..736011923 100644 --- a/src/dioptra/restapi/routes.py +++ b/src/dioptra/restapi/routes.py @@ -34,6 +34,8 @@ def register_routes(api: Api, app: Flask) -> None: """ from .auth import register_routes as attach_auth from .experiment import register_routes as attach_experiment + from .group import register_routes as attach_group + from .group_membership import register_routes as attach_group_membership from .job import register_routes as attach_job from .queue import register_routes as attach_job_queue from .task_plugin import register_routes as attach_task_plugin @@ -45,3 +47,5 @@ def register_routes(api: Api, app: Flask) -> None: attach_job_queue(api, app) attach_task_plugin(api, app) attach_user(api, app) + attach_group(api, app) + attach_group_membership(api, app) diff --git a/src/dioptra/restapi/task_plugin/controller.py b/src/dioptra/restapi/task_plugin/controller.py index 7e2b5003d..ebbaa529a 100644 --- a/src/dioptra/restapi/task_plugin/controller.py +++ b/src/dioptra/restapi/task_plugin/controller.py @@ -148,13 +148,13 @@ def get(self, taskPluginName: str) -> TaskPlugin: ) log.info("Request received") - task_plugin: Optional[ - TaskPlugin - ] = self._task_plugin_service.get_by_name_in_collection( - collection="dioptra_builtins", - task_plugin_name=taskPluginName, - bucket=current_app.config["DIOPTRA_PLUGINS_BUCKET"], - log=log, + task_plugin: Optional[TaskPlugin] = ( + self._task_plugin_service.get_by_name_in_collection( + collection="dioptra_builtins", + task_plugin_name=taskPluginName, + bucket=current_app.config["DIOPTRA_PLUGINS_BUCKET"], + log=log, + ) ) if task_plugin is None: @@ -219,13 +219,13 @@ def get(self, taskPluginName: str) -> TaskPlugin: ) log.info("Request received") - task_plugin: Optional[ - TaskPlugin - ] = self._task_plugin_service.get_by_name_in_collection( - collection="dioptra_custom", - task_plugin_name=taskPluginName, - bucket=current_app.config["DIOPTRA_PLUGINS_BUCKET"], - log=log, + task_plugin: Optional[TaskPlugin] = ( + self._task_plugin_service.get_by_name_in_collection( + collection="dioptra_custom", + task_plugin_name=taskPluginName, + bucket=current_app.config["DIOPTRA_PLUGINS_BUCKET"], + log=log, + ) ) if task_plugin is None: diff --git a/src/dioptra/restapi/utils.py b/src/dioptra/restapi/utils.py index 52026cb8c..faa68b6e6 100644 --- a/src/dioptra/restapi/utils.py +++ b/src/dioptra/restapi/utils.py @@ -94,8 +94,7 @@ class _ClassBasedViewFunction(Protocol): view_class: Type[View] - def __call__(self, *args, **kwargs) -> Any: - ... + def __call__(self, *args, **kwargs) -> Any: ... # noqa E704 def _new_class_view_function( diff --git a/src/dioptra/sdk/object_detection/bounding_boxes/postprocessing/tensorflow_backend/confluence.py b/src/dioptra/sdk/object_detection/bounding_boxes/postprocessing/tensorflow_backend/confluence.py index 36b587bc3..654b8ef6f 100644 --- a/src/dioptra/sdk/object_detection/bounding_boxes/postprocessing/tensorflow_backend/confluence.py +++ b/src/dioptra/sdk/object_detection/bounding_boxes/postprocessing/tensorflow_backend/confluence.py @@ -259,9 +259,9 @@ def confluence( all_proximities = np.ones_like(proximity) cconf_scores = np.zeros_like(cconf) - all_proximities[ - proximity <= self._confluence_threshold - ] = proximity[proximity <= self._confluence_threshold] + all_proximities[proximity <= self._confluence_threshold] = ( + proximity[proximity <= self._confluence_threshold] + ) cconf_scores[proximity <= self._confluence_threshold] = cconf[ proximity <= self._confluence_threshold ] @@ -334,18 +334,18 @@ def confluence( -((1 - manhattan_distance) * (1 - manhattan_distance)) / self._sigma ) - weights[ - manhattan_distance <= self._confluence_threshold - ] = gaussian_weights[ - manhattan_distance <= self._confluence_threshold - ] + weights[manhattan_distance <= self._confluence_threshold] = ( + gaussian_weights[ + manhattan_distance <= self._confluence_threshold + ] + ) else: - weights[ - manhattan_distance <= self._confluence_threshold - ] = manhattan_distance[ - manhattan_distance <= self._confluence_threshold - ] + weights[manhattan_distance <= self._confluence_threshold] = ( + manhattan_distance[ + manhattan_distance <= self._confluence_threshold + ] + ) dets[1:, 4] *= weights to_reprocess = np.where(dets[1:, 4] >= self._score_threshold)[0] @@ -506,14 +506,14 @@ def pad_retained_arrays( ) if len(batch_scores) > 0: - padded_scores[ - batch_idx, 0 : batch_scores.shape[0] - ] = batch_scores.astype("float32") + padded_scores[batch_idx, 0 : batch_scores.shape[0]] = ( + batch_scores.astype("float32") + ) if len(batch_labels) > 0: - padded_labels[ - batch_idx, 0 : batch_labels.shape[0] - ] = batch_labels.astype("int32") + padded_labels[batch_idx, 0 : batch_labels.shape[0]] = ( + batch_labels.astype("int32") + ) padded_detections[batch_idx] = len(batch_scores) diff --git a/src/dioptra/sdk/object_detection/data/tensorflow_backend.py b/src/dioptra/sdk/object_detection/data/tensorflow_backend.py index d24339355..16ba2cc78 100644 --- a/src/dioptra/sdk/object_detection/data/tensorflow_backend.py +++ b/src/dioptra/sdk/object_detection/data/tensorflow_backend.py @@ -115,15 +115,15 @@ def create( augmentations_seed: Optional[int] = None, shuffle_seed: Optional[int] = None, ) -> TensorflowObjectDetectionData: - annotation_data_registry: dict[ - str, Callable[[], PascalVOCAnnotationData] - ] = dict( - pascal_voc=lambda: PascalVOCAnnotationData( - labels=labels, - encoding=NumpyAnnotationEncoding( - boxes_dtype="float32", labels_dtype="int32" + annotation_data_registry: dict[str, Callable[[], PascalVOCAnnotationData]] = ( + dict( + pascal_voc=lambda: PascalVOCAnnotationData( + labels=labels, + encoding=NumpyAnnotationEncoding( + boxes_dtype="float32", labels_dtype="int32" + ), ), - ), + ) ) augmentations_registry: dict[ str, Callable[[], ImgAugObjectDetectionAugmentations] diff --git a/src/dioptra/task_engine/task_engine.py b/src/dioptra/task_engine/task_engine.py index a8e0b5db1..7a59092cb 100644 --- a/src/dioptra/task_engine/task_engine.py +++ b/src/dioptra/task_engine/task_engine.py @@ -468,9 +468,9 @@ def _run_experiment( ) log.debug("Global parameters:\n %s", props_values) - step_outputs: MutableMapping[ - str, MutableMapping[str, Any] - ] = collections.defaultdict(dict) + step_outputs: MutableMapping[str, MutableMapping[str, Any]] = ( + collections.defaultdict(dict) + ) step_order = util.get_sorted_steps(graph) diff --git a/task-plugins/dioptra_builtins/metrics/distance.py b/task-plugins/dioptra_builtins/metrics/distance.py index 1c74f8692..a737ba2df 100644 --- a/task-plugins/dioptra_builtins/metrics/distance.py +++ b/task-plugins/dioptra_builtins/metrics/distance.py @@ -68,9 +68,9 @@ def get_distance_metric_list( distance_metrics_list: List[Tuple[str, Callable[..., np.ndarray]]] = [] for metric in request: - metric_callable: Optional[ - Callable[..., np.ndarray] - ] = DISTANCE_METRICS_REGISTRY.get(metric["func"]) + metric_callable: Optional[Callable[..., np.ndarray]] = ( + DISTANCE_METRICS_REGISTRY.get(metric["func"]) + ) if metric_callable is not None: distance_metrics_list.append((metric["name"], metric_callable)) @@ -106,9 +106,9 @@ def get_distance_metric(func: str) -> Callable[..., np.ndarray]: Returns: A callable distance metric function. """ - metric_callable: Optional[ - Callable[..., np.ndarray] - ] = DISTANCE_METRICS_REGISTRY.get(func) + metric_callable: Optional[Callable[..., np.ndarray]] = ( + DISTANCE_METRICS_REGISTRY.get(func) + ) if metric_callable is None: LOGGER.error( diff --git a/task-plugins/dioptra_builtins/metrics/performance.py b/task-plugins/dioptra_builtins/metrics/performance.py index d66d36736..29d6147b4 100644 --- a/task-plugins/dioptra_builtins/metrics/performance.py +++ b/task-plugins/dioptra_builtins/metrics/performance.py @@ -69,9 +69,9 @@ def get_performance_metric_list( performance_metrics_list: List[Tuple[str, Callable[..., float]]] = [] for metric in request: - metric_callable: Optional[ - Callable[..., float] - ] = PERFORMANCE_METRICS_REGISTRY.get(metric["func"]) + metric_callable: Optional[Callable[..., float]] = ( + PERFORMANCE_METRICS_REGISTRY.get(metric["func"]) + ) if metric_callable is not None: performance_metrics_list.append((metric["name"], metric_callable)) diff --git a/tests/containers/mlflow_tracking/test_entrypoint.py b/tests/containers/mlflow_tracking/test_entrypoint.py index e7adb986f..9ef038cf7 100644 --- a/tests/containers/mlflow_tracking/test_entrypoint.py +++ b/tests/containers/mlflow_tracking/test_entrypoint.py @@ -54,7 +54,9 @@ def host(container: Container) -> Host: @pytest.fixture(scope="function") def print_db_tables_pyscript(container: Container, tmp_path: Path) -> str: - pyscript: str | bytes = """ + pyscript: ( + str | bytes + ) = """ import sqlite3\n con = sqlite3.connect("/work/mlruns/mlflow-tracking.db") diff --git a/tests/cookiecutter_dioptra_deployment/conftest.py b/tests/cookiecutter_dioptra_deployment/conftest.py index 72a3f9937..3e14e2449 100644 --- a/tests/cookiecutter_dioptra_deployment/conftest.py +++ b/tests/cookiecutter_dioptra_deployment/conftest.py @@ -117,7 +117,7 @@ def context(): "image": "node", "namespace": "", "tag": "latest", - "registry": "" + "registry": "", }, "redis": { "image": "redis", diff --git a/tests/unit/restapi/conftest.py b/tests/unit/restapi/conftest.py index dc8edb11c..41427fa69 100644 --- a/tests/unit/restapi/conftest.py +++ b/tests/unit/restapi/conftest.py @@ -58,9 +58,10 @@ def task_plugins_dir(tmp_path_factory): def workflow_tar_gz(): workflow_tar_gz_fileobj: BinaryIO = io.BytesIO() - with tarfile.open(fileobj=workflow_tar_gz_fileobj, mode="w:gz") as f, io.BytesIO( - initial_bytes=b"data" - ) as data: + with ( + tarfile.open(fileobj=workflow_tar_gz_fileobj, mode="w:gz") as f, + io.BytesIO(initial_bytes=b"data") as data, + ): tarinfo = tarfile.TarInfo(name="MLproject") tarinfo.size = len(data.getbuffer()) f.addfile(tarinfo=tarinfo, fileobj=data) @@ -74,9 +75,11 @@ def workflow_tar_gz(): def task_plugin_archive(): archive_fileobj: BinaryIO = io.BytesIO() - with tarfile.open(fileobj=archive_fileobj, mode="w:gz") as f, io.BytesIO( - initial_bytes=b"# init file" - ) as f_init, io.BytesIO(initial_bytes=b"# plugin module") as f_plugin_module: + with ( + tarfile.open(fileobj=archive_fileobj, mode="w:gz") as f, + io.BytesIO(initial_bytes=b"# init file") as f_init, + io.BytesIO(initial_bytes=b"# plugin module") as f_plugin_module, + ): tarinfo_init = tarfile.TarInfo(name="new_plugin_module/__init__.py") tarinfo_init.size = len(f_init.getbuffer()) f.addfile(tarinfo=tarinfo_init, fileobj=f_init) diff --git a/tests/unit/restapi/group/__init__.py b/tests/unit/restapi/group/__init__.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/tests/unit/restapi/group/__init__.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/tests/unit/restapi/group/test_group.py b/tests/unit/restapi/group/test_group.py new file mode 100644 index 000000000..7db91933d --- /dev/null +++ b/tests/unit/restapi/group/test_group.py @@ -0,0 +1,476 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Test suite for queue operations. + +This module contains a set of tests that validate the CRUD operations and additional +functionalities for the queue entity. The tests ensure that the queues can be +registered, renamed, deleted, and locked/unlocked as expected through the REST API. +""" +from __future__ import annotations + +import datetime +from typing import Any + +import pytest +from flask.testing import FlaskClient +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from werkzeug.test import TestResponse + +from dioptra.restapi.group.model import Group +from dioptra.restapi.group.service import GroupService +from dioptra.restapi.group_membership.model import GroupMembership +from dioptra.restapi.group_membership.service import GroupMembershipService +from dioptra.restapi.queue.routes import BASE_ROUTE as QUEUE_BASE_ROUTE +from dioptra.restapi.user.model import User + + +@pytest.fixture +def group_service() -> GroupService: + yield GroupService() + + +@pytest.fixture +def group_membership_service() -> GroupMembershipService: + yield GroupMembershipService() + + +###### helpers + + +def create_user(db: SQLAlchemy) -> User: + """Create a user and add them to the database. + + Args: + db: The SQLAlchemy database session. + + Returns: + The newly created user object. + """ + timestamp = datetime.datetime.now() + user_expire_on = datetime.datetime(9999, 12, 31, 23, 59, 59) + password_expire_on = timestamp.replace(year=timestamp.year + 1) + + new_user: User = User( + username="test_admin", + password="password", + email_address="test@test.com", + created_on=timestamp, + last_modified_on=timestamp, + last_login_on=timestamp, + password_expire_on=password_expire_on, + ) + db.session.add(new_user) + db.session.commit() + return new_user + + +def create_group(group_service: GroupService, name: str = "test") -> Group: + """Create a group using the group service. + + Args: + group_service: The group service responsible for group creation. + name: The name to assign to the new group (default is "test"). + + Returns: + The response from the group service representing the newly created group. + """ + return group_service.create(name) + + +def get_group(id: int, group_service: GroupService) -> Group | None: + """Retrieve a group by its unique identifier. + + Args: + id: The unique identifier of the group. + group_service: The service responsible for handling group-related operations. + + Returns: + The retrieved Group object. + + Raises: + GroupDoesNotExistError: If no group with the specified ID is found. + """ + return group_service.get(id) + + +def delete_group(id: int, group_service: GroupService) -> dict[str, Any]: + """Delete a group by its unique identifier. + + Args: + id: The unique identifier of the group. + group_service: The service responsible for handling group operations. + + Returns: + A dictionary reporting the status of the request. + + Raises: + GroupDoesNotExistError: If the group with the specified ID does not exist. + """ + return group_service.delete(id) + + +def create_group_membership( + group_id: int, + user_id: int, + read: bool, + write: bool, + share_read: bool, + share_write: bool, + group_membership_service: GroupMembershipService, +) -> GroupMembershipService | None: + """Create a group membership. + + Args: + group_id: The unique identifier of the group. + user_id: The unique identifier of the user. + read: Whether the user has read permissions. + write: Whether the user has write permissions. + share_read: Whether the user has share-read permissions. + share_write: Whether the user has share-write permissions. + group_membership_service: The service responsible for handling group memberships. + + Returns: + The created GroupMembership object representing the group membership. + + Raises: + GroupMembershipSubmissionError: If there is an issue with the submission. + """ + return group_membership_service.create( + group_id, + user_id, + read=read, + write=write, + share_read=share_read, + share_write=share_write, + ) + + +def get_group_membership( + user_id: int, group_id: int, group_membership_service: GroupMembershipService +) -> GroupMembership | None: + """Retrieve a group membership for a user in a specific group. + + Args: + user_id: The unique identifier of the user. + group_id: The unique identifier of the group. + group_membership_service: The service responsible for handling group membership operations. + + Returns: + The retrieved GroupMembership object if found, otherwise None. + + Raises: + GroupMembershipDoesNotExistError: If no group membership with the specified user and group IDs is found. + """ + return group_membership_service.get(group_id, user_id) + + +def delete_group_membership( + group_id: int, user_id: int, group_membership_service: Any +) -> dict[str, Any]: + """Delete a group membership. + + Args: + group_id: The unique identifier of the group. + user_id: The unique identifier of the user. + group_membership_service: The service responsible for group membership operations. + + Returns: + A dictionary reporting the status of the request. + """ + return group_membership_service.delete(group_id, user_id) + + +##### asserts + + +def assert_group_membership_created(membership: Any, group: Any, new_user: Any) -> None: + """Assert that the group membership has been created. + + Args: + membership: The created group membership object. + group: The group to which the user is added. + new_user: The user added to the group. + """ + assert membership.group_id == group.group_id + assert membership.user_id == new_user.user_id + assert membership.read is True + assert membership.write is True + assert membership.share_read is True + assert membership.share_write is True + + assert new_user in group.users + + +def assert_group_membership_does_not_exist( + membership: Any, group: Any, new_user: Any +) -> None: + """Assert that the group membership has not been created. + + Args: + membership: The created group membership object (should be None). + group: The group to which the user should not be added. + new_user: The user that should not be added to the group. + """ + assert membership is None + + +def assert_membership_group( + retrieved_membership: GroupMembership, + group: Group, +) -> None: + """Assert that the group in the retrieved membership matches the expected group. + + Args: + retrieved_membership: The retrieved membership object. + group: The expected group object. + + Raises: + AssertionError: If the groups do not match. + """ + assert retrieved_membership.group == group + + +def assert_group_in_list( + group: Group, + group_service: GroupService, +) -> None: + """Assert that the group is in the list of all groups retrieved from the service. + + Args: + group: The group object to check. + group_service: The group service. + + Raises: + AssertionError: If the group is not found in the list. + """ + assert group in group_service.get_all() + + +def assert_group_is_none(retrieved_group: Group | None) -> None: + """Assert that the retrieved group is None. + + Args: + retrieved_group: The retrieved group object. + + Raises: + AssertionError: If the retrieved group is not None. + """ + assert retrieved_group is None + + +def assert_membership_user_equals( + retrieved_membership: GroupMembership, new_user: User +) -> None: + """Assert that the user in the retrieved membership equals the new user. + + Args: + retrieved_membership: The retrieved group membership object. + new_user: The new user object. + + Raises: + AssertionError: If the user in the retrieved membership does not equal the new user. + """ + assert retrieved_membership.user == new_user + + +###### tests + + +def test_create_group(db: SQLAlchemy, group_service: GroupService): + """Test the creation of a group using the database and group service. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group creation. + """ + group = create_group(group_service, name="Test Group") + + assert_group_in_list(group, group_service) + + +def test_delete_group(db: SQLAlchemy, group_service: GroupService): + """Test the deletion of a group using the database and group service. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group operations. + """ + group = create_group(group_service, name="Test Group") + delete_group(group.group_id, group_service) + retrieved_group = get_group(group.group_id, group_service) + + assert_group_is_none(retrieved_group) + + +def test_create_group_membership( + db: SQLAlchemy, + group_service: GroupService, + group_membership_service: GroupMembershipService, +) -> None: + """Test the creation of a group membership using the database and services. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group operations. + group_membership_service: The group membership service responsible for group membership operations. + """ # Create a user + + new_user = create_user(db) + + # Create a group + group = create_group(group_service, name="Test Group") + + # Create a group membership + membership = create_group_membership( + group.group_id, + new_user.user_id, + read=True, + write=True, + share_read=True, + share_write=True, + group_membership_service=group_membership_service, + ) + + assert_group_membership_created(membership, group, new_user) + + +def test_delete_group_membership( + db: SQLAlchemy, + group_service: GroupService, + group_membership_service: GroupMembershipService, +) -> None: + """Test the deletion of a group membership using the database and services. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group operations. + group_membership_service: The group membership service responsible for group membership operations. + """ + # Create a user + + new_user = create_user(db) + + # Create a group + group = create_group(group_service, name="Test Group") + + # Create a group membership + membership = create_group_membership( + group.group_id, + new_user.user_id, + read=True, + write=True, + share_read=True, + share_write=True, + group_membership_service=group_membership_service, + ) + + # get membership from db + retrieved_membership = get_group_membership( + group.group_id, new_user.user_id, group_membership_service + ) + + # and then delete it + delete_group_membership( + retrieved_membership.group_id, + retrieved_membership.user_id, + group_membership_service, + ) + + # get membership from db + retrieved_membership = get_group_membership( + group.group_id, new_user.user_id, group_membership_service + ) + + assert_group_membership_does_not_exist(retrieved_membership, group, new_user) + + +def test_group_relationship( + db: SQLAlchemy, + group_service: GroupService, + group_membership_service: GroupMembershipService, +) -> None: + """Test the relationship between groups and group memberships using the database and services. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group operations. + group_membership_service: The group membership service responsible for group membership operations. + """ + # Create a user + + new_user = create_user(db) + + # Create a group + group = create_group(group_service, name="Test Group") + + # Create a group membership + membership = create_group_membership( + group.group_id, + new_user.user_id, + read=True, + write=True, + share_read=True, + share_write=True, + group_membership_service=group_membership_service, + ) + + # get membership from db + retrieved_membership = get_group_membership( + group.group_id, new_user.user_id, group_membership_service + ) + + assert_membership_group(retrieved_membership, group) + + +def test_user_relationship( + db: SQLAlchemy, + group_service: GroupService, + group_membership_service: GroupMembershipService, +) -> None: + """Test the relationship between users and group memberships using the database and services. + + Args: + db: The SQLAlchemy database session for testing. + group_service: The group service responsible for group operations. + group_membership_service: The group membership service responsible for group membership operations. + """ + # Create a user + + new_user = create_user(db) + + # Create a group + group = create_group(group_service, name="Test Group") + + # Create a group membership + membership = create_group_membership( + group.group_id, + new_user.user_id, + read=True, + write=True, + share_read=True, + share_write=True, + group_membership_service=group_membership_service, + ) + + # get membership from db + retrieved_membership = get_group_membership( + group.group_id, new_user.user_id, group_membership_service + ) + + assert_membership_user_equals(retrieved_membership, new_user) diff --git a/tests/unit/restapi/task_plugin/test_schema.py b/tests/unit/restapi/task_plugin/test_schema.py index 0758d014a..0bbc03968 100644 --- a/tests/unit/restapi/task_plugin/test_schema.py +++ b/tests/unit/restapi/task_plugin/test_schema.py @@ -111,9 +111,9 @@ def test_TaskPluginUploadFormSchema_dump_works( task_plugin_upload_form_schema: TaskPluginUploadFormSchema, task_plugin_archive: BinaryIO, ) -> None: - task_plugin_upload_form_serialized: Dict[ - str, Any - ] = task_plugin_upload_form_schema.dump(task_plugin_upload_form) + task_plugin_upload_form_serialized: Dict[str, Any] = ( + task_plugin_upload_form_schema.dump(task_plugin_upload_form) + ) assert task_plugin_upload_form_serialized["task_plugin_name"] == "new_plugin_one" assert task_plugin_upload_form_serialized["collection"] == "dioptra_custom" diff --git a/tests/unit/restapi/test_user.py b/tests/unit/restapi/test_user.py index a6cc95311..647ddb463 100644 --- a/tests/unit/restapi/test_user.py +++ b/tests/unit/restapi/test_user.py @@ -200,13 +200,13 @@ def _(client: RequestsSession) -> dict[str, Any]: @overload -def login(client: FlaskClient, username: str, password: str) -> TestResponse: - ... +def login(client: FlaskClient, username: str, password: str) -> TestResponse: ... @overload -def login(client: RequestsSession, username: str, password: str) -> RequestsResponse: - ... +def login( + client: RequestsSession, username: str, password: str +) -> RequestsResponse: ... @singledispatch @@ -244,13 +244,11 @@ def _(client: RequestsSession, username: str, password: str) -> RequestsResponse @overload -def logout(client: FlaskClient, everywhere: bool) -> TestResponse: - ... +def logout(client: FlaskClient, everywhere: bool) -> TestResponse: ... @overload -def logout(client: RequestsSession, everywhere: bool) -> RequestsResponse: - ... +def logout(client: RequestsSession, everywhere: bool) -> RequestsResponse: ... @singledispatch @@ -288,15 +286,13 @@ def _(client: RequestsSession, everywhere: bool) -> RequestsResponse: @overload def change_password( client: FlaskClient, user_id: int, current_password: str, new_password: str -) -> TestResponse: - ... +) -> TestResponse: ... @overload def change_password( client: RequestsSession, user_id: int, current_password: str, new_password: str -) -> RequestsResponse: - ... +) -> RequestsResponse: ... @singledispatch @@ -351,15 +347,13 @@ def _( @overload def change_current_user_password( client: FlaskClient, current_password: str, new_password: str -) -> TestResponse: - ... +) -> TestResponse: ... @overload def change_current_user_password( client: RequestsSession, current_password: str, new_password: str -) -> RequestsResponse: - ... +) -> RequestsResponse: ... @singledispatch @@ -407,13 +401,11 @@ def _( @overload -def delete_current_user(client: FlaskClient, password: str) -> TestResponse: - ... +def delete_current_user(client: FlaskClient, password: str) -> TestResponse: ... @overload -def delete_current_user(client: RequestsSession, password: str) -> RequestsResponse: - ... +def delete_current_user(client: RequestsSession, password: str) -> RequestsResponse: ... @singledispatch