From 20b7c749082f86beaf23e26c79f0df31248bc6fe Mon Sep 17 00:00:00 2001 From: andrew-uoa <142769327+andrew-uoa@users.noreply.github.com> Date: Thu, 15 Aug 2024 12:13:31 +1200 Subject: [PATCH] Stronger validation of responses from MyTardis API (#483) * Add a type for storing/validating an MD5 checksum, and associated tests * Introduce a base dataclass to be used for all the validation dataclasses representing MyTardis resources * Move hexadecimal validator to a new module specifically for validators * Extract two validation functions to the validators module * Rename validators module to avoid possible name clash with validators package * Delete unused regex * Use the ISO datetime parser from dateutils, instead of a hand-rolled regex * Extract MD5 checksum validation to a standalone function * Convert MD5Sum to an Annotated str to reduce verbosity * Move the definitions of some helper Annotated types to mytardis_client so they can be used in defining dataclasses there without introducing cyclic dependencies * Remove seemingly unnecessary boilerplate * Fix broken imports * Add dataclasses for validating/storing response data from the MyTardis API * Store the output data type associated with GET requests to each endpoint, and use this in the rest client GET calls * Use the more strongly typed GET interface in the overseer * Use the more strongly-typed GET interface in the Overseer * Introduce a MyTardisResource protocol to achieve structural subtyping in the return types from the MyTardis client * Simplify some type-checking logic * Delete commented code * Update the poetry lock file * Add package with typing stubs for the dateutil package * Fix broken tests * Avoid overriding default pagination from MyTardis if no pagination specified * Add missing f-string prefix * Fix broken Overseer tests * Rename modules to align terminology (validation instead of validate/validators) --- poetry.lock | 33 ++- pyproject.toml | 2 + src/blueprints/__init__.py | 1 - 
src/blueprints/custom_data_types.py | 48 ----- src/blueprints/datafile.py | 3 +- src/blueprints/dataset.py | 8 +- src/blueprints/experiment.py | 8 +- src/blueprints/project.py | 9 +- src/cli/cmd_clean.py | 15 +- src/config/config.py | 2 +- src/ingestion_factory/factory.py | 8 +- src/inspector/inspector.py | 13 +- src/mytardis_client/common_types.py | 22 +- src/mytardis_client/endpoint_info.py | 42 ++-- src/mytardis_client/mt_rest.py | 30 +-- src/mytardis_client/response_data.py | 222 ++++++++++++++++++++- src/overseers/overseer.py | 96 ++------- src/profiles/abi_music/parsing.py | 3 +- src/profiles/ro_crate/ro_crate_parser.py | 2 +- src/smelters/smelter.py | 2 +- src/utils/types/type_helpers.py | 16 ++ src/utils/validation.py | 54 +++++ tests/fixtures/fixtures_constants.py | 3 +- tests/fixtures/fixtures_dataclasses.py | 4 +- tests/fixtures/fixtures_responses.py | 6 +- tests/test_custom_data_types.py | 44 +--- tests/test_dataclasses.py | 2 +- tests/test_mytardis_client_rest_factory.py | 22 +- tests/test_overseers.py | 67 +++---- tests/test_utils_validation.py | 114 +++++++++++ 30 files changed, 609 insertions(+), 292 deletions(-) create mode 100644 src/utils/types/type_helpers.py create mode 100644 src/utils/validation.py create mode 100644 tests/test_utils_validation.py diff --git a/poetry.lock b/poetry.lock index 03f4ada6..fe293444 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
[[package]] name = "aiohappyeyeballs" @@ -3452,6 +3452,24 @@ tinydb = ">=3.5.0" dev = ["Sphinx (==1.7.1)", "sphinx-autobuild (==2021.3.14)", "tox (>=2.3.1)"] test = ["aioresponses (>=0.6.2)", "coverage (>=4.2)", "parametrize (>=0.1.1)", "pytest (>=3.0.3)", "pytest-cov (>=2.3.1,<2.6)", "responses (>=0.5.1)"] +[[package]] +name = "typeguard" +version = "4.3.0" +description = "Run-time type checker for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typeguard-4.3.0-py3-none-any.whl", hash = "sha256:4d24c5b39a117f8a895b9da7a9b3114f04eb63bade45a4492de49b175b6f7dfa"}, + {file = "typeguard-4.3.0.tar.gz", hash = "sha256:92ee6a0aec9135181eae6067ebd617fd9de8d75d714fb548728a4933b1dea651"}, +] + +[package.dependencies] +typing-extensions = ">=4.10.0" + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.3.0)"] +test = ["coverage[toml] (>=7)", "mypy (>=1.2.0)", "pytest (>=7)"] + [[package]] name = "typer" version = "0.12.3" @@ -3480,6 +3498,17 @@ files = [ {file = "types_mock-5.1.0.20240425-py3-none-any.whl", hash = "sha256:d586a01d39ad919d3ddcd73de6cde73ca7f3c69707219f722d1b8d7733641ad7"}, ] +[[package]] +name = "types-python-dateutil" +version = "2.9.0.20240316" +description = "Typing stubs for python-dateutil" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-python-dateutil-2.9.0.20240316.tar.gz", hash = "sha256:5d2f2e240b86905e40944dd787db6da9263f0deabef1076ddaed797351ec0202"}, + {file = "types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"}, +] + [[package]] name = "types-python-slugify" version = "8.0.2.20240310" @@ -3734,4 +3763,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "9bd0326450dcca271d8cfe2598e3bb447a71f6933abb49143262e8447dc2e415" +content-hash = 
"7b7a3babb94d6fdc6d02f2782b26330b24f4c2ca72b8a0f9f04ef7f95ab277d2" diff --git a/pyproject.toml b/pyproject.toml index a75be46c..975d1b7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ typer = "^0.12.0" rocrate = "^0.10.0" typing-extensions = "^4.12.2" types-pyyaml = "^6.0.12.20240311" +typeguard = "^4.3.0" [tool.poetry.group.dev.dependencies] wily = "^1.25.0" @@ -53,6 +54,7 @@ pyfakefs = "^5.3.1" black = "^24.4.0" isort = "^5.13.2" deptry = "^0.17.0" +types-python-dateutil = "^2.9.0.20240316" [tool.poetry.scripts] ids = 'src.ids:app' diff --git a/src/blueprints/__init__.py b/src/blueprints/__init__.py index 2dbcbe7b..29ef7647 100644 --- a/src/blueprints/__init__.py +++ b/src/blueprints/__init__.py @@ -1,7 +1,6 @@ # pylint: disable=missing-module-docstring, R0801 from src.blueprints.common_models import GroupACL, Parameter, ParameterSet, UserACL -from src.blueprints.custom_data_types import ISODateTime, MTUrl, Username from src.blueprints.datafile import ( BaseDatafile, Datafile, diff --git a/src/blueprints/custom_data_types.py b/src/blueprints/custom_data_types.py index 7ebb73f4..d778a0f6 100644 --- a/src/blueprints/custom_data_types.py +++ b/src/blueprints/custom_data_types.py @@ -8,19 +8,11 @@ from typing import Annotated, Any from pydantic import AfterValidator, PlainSerializer, WithJsonSchema -from validators import url user_regex = re.compile( r"^[a-z]{2,4}[0-9]{3}$" # Define as a constant in case of future change ) -iso_time_regex = re.compile( - r"^(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]+)?(Z|[+-](?:2[0-3]|[01][0-9]):[0-5][0-9])?$" # pylint: disable=line-too-long -) -iso_date_regex = re.compile( - r"^(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])" -) - def validate_username(value: Any) -> str: """Defines a validated username, in other words, ensure that the username meets a standardised @@ -42,43 +34,3 @@ def validate_username(value: 
Any) -> str: PlainSerializer(lambda x: str(x), return_type=str), WithJsonSchema({"type": "string"}, mode="serialization"), ] - - -def validate_isodatetime(value: Any) -> str: - """Custom validator to ensure that the value is a string object and that it matches - the regex defined for an ISO 8601 formatted datetime string""" - if not isinstance(value, str): - raise TypeError(f'Unexpected type for ISO date/time stamp: "{type(value)}"') - if match := iso_time_regex.fullmatch(value): - return f"{match.group(0)}" - raise ValueError( - 'Passed string value "%s" is not an ISO 8601 formatted ' - "date/time string. Format should follow " - "YYYY-MM-DDTHH:MM:SS.SSSSSS+HH:MM convention" % (value) - ) - - -ISODateTime = Annotated[ - str, - AfterValidator(validate_isodatetime), - PlainSerializer(lambda x: str(x), return_type=str), - WithJsonSchema({"type": "string"}, mode="serialization"), -] - - -def validate_url(value: Any) -> str: - """Custom validator for Urls since the default pydantic ones are not compatible - with urllib""" - if not isinstance(value, str): - raise TypeError(f'Unexpected type for URL: "{type(value)}"') - if url(value): - return value - raise ValueError(f'Passed string value"{value}" is not a valid URL') - - -MTUrl = Annotated[ - str, - AfterValidator(validate_url), - PlainSerializer(lambda x: str(x), return_type=str), - WithJsonSchema({"type": "string"}, mode="serialization"), -] diff --git a/src/blueprints/datafile.py b/src/blueprints/datafile.py index 4e611e31..ed8bc91c 100644 --- a/src/blueprints/datafile.py +++ b/src/blueprints/datafile.py @@ -8,8 +8,7 @@ from pydantic import BaseModel, Field, field_serializer from src.blueprints.common_models import GroupACL, ParameterSet, UserACL -from src.blueprints.custom_data_types import MTUrl -from src.mytardis_client.common_types import DataStatus +from src.mytardis_client.common_types import DataStatus, MTUrl from src.mytardis_client.endpoints import URI diff --git a/src/blueprints/dataset.py 
b/src/blueprints/dataset.py index 16a1cfae..2f370abc 100644 --- a/src/blueprints/dataset.py +++ b/src/blueprints/dataset.py @@ -9,8 +9,12 @@ from pydantic import BaseModel, Field from src.blueprints.common_models import GroupACL, ParameterSet, UserACL -from src.blueprints.custom_data_types import ISODateTime, MTUrl -from src.mytardis_client.common_types import DataClassification, DataStatus +from src.mytardis_client.common_types import ( + DataClassification, + DataStatus, + ISODateTime, + MTUrl, +) from src.mytardis_client.endpoints import URI diff --git a/src/blueprints/experiment.py b/src/blueprints/experiment.py index d8495113..89a715b0 100644 --- a/src/blueprints/experiment.py +++ b/src/blueprints/experiment.py @@ -8,8 +8,12 @@ from pydantic import BaseModel, Field from src.blueprints.common_models import GroupACL, ParameterSet, UserACL -from src.blueprints.custom_data_types import ISODateTime, MTUrl -from src.mytardis_client.common_types import DataClassification, DataStatus +from src.mytardis_client.common_types import ( + DataClassification, + DataStatus, + ISODateTime, + MTUrl, +) from src.mytardis_client.endpoints import URI diff --git a/src/blueprints/project.py b/src/blueprints/project.py index c7ad5b2e..7db2f9d9 100644 --- a/src/blueprints/project.py +++ b/src/blueprints/project.py @@ -8,8 +8,13 @@ from pydantic import BaseModel, Field from src.blueprints.common_models import GroupACL, ParameterSet, UserACL -from src.blueprints.custom_data_types import ISODateTime, MTUrl, Username -from src.mytardis_client.common_types import DataClassification, DataStatus +from src.blueprints.custom_data_types import Username +from src.mytardis_client.common_types import ( + DataClassification, + DataStatus, + ISODateTime, + MTUrl, +) from src.mytardis_client.endpoints import URI diff --git a/src/cli/cmd_clean.py b/src/cli/cmd_clean.py index 5355f0f5..ab901b54 100644 --- a/src/cli/cmd_clean.py +++ b/src/cli/cmd_clean.py @@ -4,7 +4,7 @@ import sys from datetime import 
datetime from pathlib import Path -from typing import Annotated, Any, Optional +from typing import Annotated, Optional import typer @@ -19,6 +19,7 @@ ) from src.config.config import ConfigFromEnv from src.inspector.inspector import Inspector +from src.mytardis_client.response_data import IngestedDatafile, Replica from src.profiles.profile_register import load_profile from src.utils import log_utils from src.utils.timing import Timer @@ -26,20 +27,20 @@ logger = logging.getLogger(__name__) -def _get_verified_replica(queried_df: dict[str, Any]) -> Optional[dict[str, Any]]: +def _get_verified_replica(queried_df: IngestedDatafile) -> Optional[Replica]: """Returns the first Replica that is verified, or None if there isn't any.""" - replicas: list[dict[str, Any]] = queried_df["replicas"] + # Iterate through all replicas. If one replica is verified, then # return it. - for replica in replicas: - if replica["verified"]: + for replica in queried_df.replicas: + if replica.verified: return replica return None def _is_completed_df( df: RawDatafile, - query_result: Optional[list[dict[str, Any]]], + query_result: Optional[list[IngestedDatafile]], min_file_age: Optional[int], ) -> bool: """Checks if a datafile has been ingested, verified, and its age is higher @@ -58,7 +59,7 @@ def _is_completed_df( return False if min_file_age: # Check file age for the replica. 
- vtime = datetime.fromisoformat(replica["last_verified_time"]) + vtime = datetime.fromisoformat(replica.last_verified_time) days_from_vtime = (datetime.now() - vtime).days logger.info("%s was last verified %i days ago.", pth, days_from_vtime) if days_from_vtime < min_file_age: diff --git a/src/config/config.py b/src/config/config.py index fdb2040b..e76ad9e4 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -21,8 +21,8 @@ from requests import PreparedRequest from requests.auth import AuthBase -from src.blueprints.custom_data_types import MTUrl from src.blueprints.storage_boxes import StorageTypesEnum +from src.mytardis_client.common_types import MTUrl logger = logging.getLogger(__name__) diff --git a/src/ingestion_factory/factory.py b/src/ingestion_factory/factory.py index f031ce18..6d4db1ae 100644 --- a/src/ingestion_factory/factory.py +++ b/src/ingestion_factory/factory.py @@ -67,7 +67,7 @@ def __init__( self.config = config mt_rest = mt_rest or MyTardisRESTFactory(config.auth, config.connection) - self._overseer = overseer or Overseer(mt_rest) + self._overseer: Overseer = overseer or Overseer(mt_rest) self.forge = forge or Forge(mt_rest) self.smelter = smelter or Smelter( @@ -105,7 +105,7 @@ def ingest_projects( MyTardisObject.PROJECT, project.model_dump() ) if len(matching_projects) > 0: - project_uri = URI(matching_projects[0]["resource_uri"]) + project_uri = matching_projects[0].resource_uri # Would we ever get multiple matches? If so, what should we do? logging.info( 'Already ingested project "%s" as "%s". Skipping project ingestion.', @@ -145,7 +145,7 @@ def ingest_experiments( MyTardisObject.EXPERIMENT, experiment.model_dump() ) if len(matching_experiments) > 0: - experiment_uri = URI(matching_experiments[0]["resource_uri"]) + experiment_uri = matching_experiments[0].resource_uri logging.info( 'Already ingested experiment "%s" as "%s". 
Skipping experiment ingestion.', experiment.title, @@ -184,7 +184,7 @@ def ingest_datasets( MyTardisObject.DATASET, dataset.model_dump() ) if len(matching_datasets) > 0: - dataset_uri = URI(matching_datasets[0]["resource_uri"]) + dataset_uri = matching_datasets[0].resource_uri logging.info( 'Already ingested dataset "%s" as "%s". Skipping dataset ingestion.', dataset.description, diff --git a/src/inspector/inspector.py b/src/inspector/inspector.py index 7f1bdf20..80fe7ded 100644 --- a/src/inspector/inspector.py +++ b/src/inspector/inspector.py @@ -3,13 +3,16 @@ """ import logging -from typing import Any, Optional +from typing import Optional + +from typeguard import check_type from src.blueprints.datafile import RawDatafile from src.config.config import ConfigFromEnv from src.crucible.crucible import Crucible from src.mytardis_client.mt_rest import MyTardisRESTFactory from src.mytardis_client.objects import MyTardisObject +from src.mytardis_client.response_data import IngestedDatafile from src.overseers.overseer import Overseer from src.smelters.smelter import Smelter @@ -28,11 +31,11 @@ def __init__(self, config: ConfigFromEnv) -> None: default_schema=config.default_schema, ) crucible = Crucible(overseer) - self._overseer = overseer + self._overseer: Overseer = overseer self._smelter = smelter self._crucible = crucible - def query_datafile(self, raw_df: RawDatafile) -> Optional[list[dict[str, Any]]]: + def query_datafile(self, raw_df: RawDatafile) -> Optional[list[IngestedDatafile]]: """Partially ingests raw datafile and queries MyTardis for matching instances. Args: @@ -50,6 +53,8 @@ def query_datafile(self, raw_df: RawDatafile) -> Optional[list[dict[str, Any]]]: return None # Look up the datafile in MyTardis to check if it's ingested. 
- return self._overseer.get_matching_objects( + matches = self._overseer.get_matching_objects( MyTardisObject.DATAFILE, df.model_dump() ) + + return check_type(matches, list[IngestedDatafile]) diff --git a/src/mytardis_client/common_types.py b/src/mytardis_client/common_types.py index cbbee641..114b8f66 100644 --- a/src/mytardis_client/common_types.py +++ b/src/mytardis_client/common_types.py @@ -4,7 +4,11 @@ objects.py for that.""" from enum import Enum -from typing import Literal +from typing import Annotated, Literal + +from pydantic import AfterValidator + +from src.utils.validation import validate_isodatetime, validate_md5sum, validate_url # The HTTP methods supported by MyTardis. Can be used to constrain the request interfaces # to ensure that only methods that are supported by MyTardis are used. @@ -35,3 +39,19 @@ class DataStatus(Enum): READY_FOR_INGESTION = 1 INGESTED = 5 FAILED = 10 + + +MD5Sum = Annotated[ + str, + AfterValidator(validate_md5sum), +] + +ISODateTime = Annotated[ + str, + AfterValidator(validate_isodatetime), +] + +MTUrl = Annotated[ + str, + AfterValidator(validate_url), +] diff --git a/src/mytardis_client/endpoint_info.py b/src/mytardis_client/endpoint_info.py index 989301d0..5c701f55 100644 --- a/src/mytardis_client/endpoint_info.py +++ b/src/mytardis_client/endpoint_info.py @@ -6,6 +6,21 @@ from src.mytardis_client.endpoints import MyTardisEndpoint from src.mytardis_client.objects import MyTardisObject +from src.mytardis_client.response_data import ( + DatasetParameterSet, + ExperimentParameterSet, + Facility, + IngestedDatafile, + IngestedDataset, + IngestedExperiment, + IngestedProject, + Institution, + Instrument, + MyTardisIntrospection, + MyTardisResourceBase, + ProjectParameterSet, + StorageBox, +) class GetRequestProperties(BaseModel): @@ -15,14 +30,13 @@ class GetRequestProperties(BaseModel): # the response can be validated/deserialized without the requester needing # to know the correct type. 
But the dataclasses are currently defined outside the # mytardis_client module, and this module should ideally be self-contained. - response_obj_type: MyTardisObject + response_obj_type: type[MyTardisResourceBase] class PostRequestProperties(BaseModel): """Definition of behaviour/structure for a POST request to a MyTardis endpoint.""" expect_response_json: bool - # response_obj_type: MyTardisObject request_body_obj_type: MyTardisObject @@ -46,7 +60,7 @@ class MyTardisEndpointInfo(BaseModel): path="/project", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.PROJECT, + response_obj_type=IngestedProject, ), POST=PostRequestProperties( expect_response_json=True, @@ -58,7 +72,7 @@ class MyTardisEndpointInfo(BaseModel): path="/experiment", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.EXPERIMENT, + response_obj_type=IngestedExperiment, ), POST=PostRequestProperties( expect_response_json=True, @@ -70,7 +84,7 @@ class MyTardisEndpointInfo(BaseModel): path="/dataset", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.DATASET, + response_obj_type=IngestedDataset, ), POST=PostRequestProperties( expect_response_json=True, @@ -82,7 +96,7 @@ class MyTardisEndpointInfo(BaseModel): path="/dataset_file", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.DATAFILE, + response_obj_type=IngestedDatafile, ), POST=PostRequestProperties( expect_response_json=False, @@ -94,7 +108,7 @@ class MyTardisEndpointInfo(BaseModel): path="/institution", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.INSTITUTION, + response_obj_type=Institution, ), ), ), @@ -102,7 +116,7 @@ class MyTardisEndpointInfo(BaseModel): path="/instrument", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.INSTRUMENT, + response_obj_type=Instrument, ), ), ), @@ -110,7 +124,7 @@ class 
MyTardisEndpointInfo(BaseModel): path="/facility", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.FACILITY, + response_obj_type=Facility, ), ), ), @@ -118,7 +132,7 @@ class MyTardisEndpointInfo(BaseModel): path="/storagebox", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.STORAGE_BOX, + response_obj_type=StorageBox, ), ), ), @@ -126,7 +140,7 @@ class MyTardisEndpointInfo(BaseModel): path="/projectparameterset", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.PROJECT_PARAMETER_SET, + response_obj_type=ProjectParameterSet, ), POST=PostRequestProperties( expect_response_json=False, @@ -138,7 +152,7 @@ class MyTardisEndpointInfo(BaseModel): path="/experimentparameterset", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.EXPERIMENT_PARAMETER_SET, + response_obj_type=ExperimentParameterSet, ), POST=PostRequestProperties( expect_response_json=False, @@ -150,7 +164,7 @@ class MyTardisEndpointInfo(BaseModel): path="/datasetparameterset", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.DATASET_PARAMETER_SET, + response_obj_type=DatasetParameterSet, ), POST=PostRequestProperties( expect_response_json=False, @@ -162,7 +176,7 @@ class MyTardisEndpointInfo(BaseModel): path="/introspection", methods=EndpointMethods( GET=GetRequestProperties( - response_obj_type=MyTardisObject.INTROSPECTION, + response_obj_type=MyTardisIntrospection, ), ), ), diff --git a/src/mytardis_client/mt_rest.py b/src/mytardis_client/mt_rest.py index 7b4a4de2..ad97ffb4 100644 --- a/src/mytardis_client/mt_rest.py +++ b/src/mytardis_client/mt_rest.py @@ -17,7 +17,9 @@ from src.config.config import AuthConfig, ConnectionConfig from src.mytardis_client.common_types import HttpRequestMethod +from src.mytardis_client.endpoint_info import get_endpoint_info from src.mytardis_client.endpoints import URI, MyTardisEndpoint +from 
src.mytardis_client.response_data import MyTardisResource # Defines the valid values for the MyTardis API version MyTardisApiVersion = Literal["v1"] @@ -258,10 +260,9 @@ def request( def get( self, endpoint: MyTardisEndpoint, - object_type: type[MyTardisObjectData], query_params: Optional[dict[str, Any]] = None, meta_params: Optional[GetRequestMetaParams] = None, - ) -> tuple[list[Ingested[MyTardisObjectData]], GetResponseMeta]: + ) -> tuple[list[MyTardisResource], GetResponseMeta]: """Submit a GET request to the MyTardis API and return the response as a list of objects. Note that the response is paginated, so the function may not return all objects matching @@ -269,13 +270,15 @@ def get( returned. To get all objects matching 'query_params', use the 'get_all()' method. """ - if meta_params is None: - meta_params = GetRequestMetaParams(limit=10, offset=0) + endpoint_info = get_endpoint_info(endpoint) + if endpoint_info.methods.GET is None: + raise RuntimeError(f"GET method not supported for endpoint '{endpoint}'") - params = meta_params.model_dump() + params = query_params - if query_params is not None: - params |= query_params + if meta_params is not None: + params = params or {} + params |= meta_params.model_dump() response_data = self.request( "GET", @@ -287,7 +290,7 @@ def get( response_meta = GetResponseMeta.model_validate(response_json["meta"]) - objects: list[Ingested[MyTardisObjectData]] = [] + objects: list[MyTardisResource] = [] response_objects = response_json.get("objects") if response_objects is None: @@ -298,23 +301,23 @@ def get( f"Response: {response_json}" ) + object_type = endpoint_info.methods.GET.response_obj_type + if not isinstance(response_objects, list): response_objects = [response_objects] for object_json in response_objects: obj = object_type.model_validate(object_json) - resource_uri = URI(object_json["resource_uri"]) - objects.append(Ingested(obj=obj, resource_uri=resource_uri)) + objects.append(obj) return objects, response_meta def 
get_all( self, endpoint: MyTardisEndpoint, - object_type: type[MyTardisObjectData], query_params: Optional[dict[str, Any]] = None, batch_size: int = 500, - ) -> tuple[list[Ingested[MyTardisObjectData]], int]: + ) -> tuple[list[MyTardisResource], int]: """Get all objects of the given type that match 'query_params'. Sends repeated GET requests to the MyTardis API until all objects have been retrieved. @@ -322,14 +325,13 @@ def get_all( each request """ - objects: list[Ingested[MyTardisObjectData]] = [] + objects: list[MyTardisResource] = [] while True: request_meta = GetRequestMetaParams(limit=batch_size, offset=len(objects)) batch_objects, response_meta = self.get( endpoint=endpoint, - object_type=object_type, query_params=query_params, meta_params=request_meta, ) diff --git a/src/mytardis_client/response_data.py b/src/mytardis_client/response_data.py index ac3c00ad..bdca559e 100644 --- a/src/mytardis_client/response_data.py +++ b/src/mytardis_client/response_data.py @@ -1,18 +1,74 @@ """Dataclasses for validating/storing MyTardis API response data.""" -from typing import Optional +from pathlib import Path +from typing import Any, Optional, Protocol from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Self +from src.mytardis_client.common_types import ( + DataClassification, + ISODateTime, + MD5Sum, + MTUrl, +) +from src.mytardis_client.endpoints import URI from src.mytardis_client.objects import MyTardisObject -class MyTardisIntrospection(BaseModel): - """MyTardis introspection data. +# pylint: disable=too-few-public-methods +class MyTardisResource(Protocol): + """Protocol for MyTardis resources.""" - NOTE: this class relies on data from the MyTardis introspection API and therefore - can't be instantiated without a request to the specific MyTardis instance. 
+ id: int + resource_uri: URI + + +class MyTardisResourceBase(BaseModel): + """Base class for data retrieved from MyTardis, associated with an ingested object.""" + + id: int + resource_uri: URI + + +class Group(MyTardisResourceBase): + """Metadata associated with a group in MyTardis.""" + + name: str + + +class Facility(MyTardisResourceBase): + """Metadata associated with a facility in MyTardis.""" + + created_time: ISODateTime + manager_group: Group + modified_time: ISODateTime + name: str + + +class Institution(MyTardisResourceBase): + """Metadata associated with an institution in MyTardis.""" + + aliases: list[str] + identifiers: list[str] + name: str + + +class Instrument(MyTardisResourceBase): + """Metadata associated with an instrument in MyTardis.""" + + created_time: ISODateTime + facility: Facility + modified_time: ISODateTime + name: str + + +class MyTardisIntrospection(MyTardisResourceBase): + """MyTardis introspection data (the configuration of the MyTardis instance). + + NOTE: this class relies on data from the MyTardis introspection API and + therefore can't be instantiated without a request to the specific MyTardis + instance. 
""" model_config = ConfigDict(use_enum_values=False) @@ -29,6 +85,23 @@ class MyTardisIntrospection(BaseModel): projects_enabled: bool profiles_enabled: bool + @model_validator(mode="before") + @classmethod + def validate_model(cls, data: Any) -> Any: + """Adapter for the raw data, as the ID from the introspection endpoint + comes back as None""" + + if isinstance(data, dict): + dummy_id = 0 + resource_uri = data.get("resource_uri") + if isinstance(resource_uri, str): + data["resource_uri"] = resource_uri.replace("None", str(dummy_id)) + + if "id" not in data: + data["id"] = dummy_id + + return data + @model_validator(mode="after") def validate_consistency(self) -> Self: """Check that the introspection data is consistent.""" @@ -44,3 +117,142 @@ def validate_consistency(self) -> Self: ) return self + + +class ParameterName(MyTardisResourceBase): + """Schema parameter information""" + + full_name: str + immutable: bool + is_searchable: bool + name: str + parent_schema: URI = Field(serialization_alias="schema") + sensitive: bool + units: str + + +class Replica(MyTardisResourceBase): + """Metadata associated with a Datafile replica in MyTardis.""" + + created_time: ISODateTime + datafile: URI + last_verified_time: ISODateTime + location: str + uri: str # Note: not a MyTardis resource URI + verified: bool + + +class Schema(MyTardisResourceBase): + """Metadata associated with a metadata schema in MyTardis.""" + + hidden: bool + immutable: bool + name: str + namespace: MTUrl # e.g. 
"http://130.216.253.65/noel-test-exp-schema", + parameter_names: list[ParameterName] + + +class StorageBoxOption(MyTardisResourceBase): + """Data associated with a storage box option in MyTardis.""" + + key: str + storage_box: URI + value: str + value_type: str + + +class StorageBox(MyTardisResourceBase): + """Metadata associated with a storage box in MyTardis.""" + + attributes: list[str] + description: str + django_storage_class: str + max_size: Optional[int] + name: str + options: list[StorageBoxOption] + status: str + + +class User(MyTardisResourceBase): + """Dataa associated with a user in MyTardis.""" + + email: Optional[str] + first_name: Optional[str] + groups: list[Group] + last_name: Optional[str] + username: str + + +class ProjectParameterSet(MyTardisResourceBase): + """Metadata associated with a project parameter set in MyTardis.""" + + +class ExperimentParameterSet(MyTardisResourceBase): + """Metadata associated with an experiment parameter set in MyTardis.""" + + +class DatasetParameterSet(MyTardisResourceBase): + """Metadata associated with a dataset parameter set in MyTardis.""" + + +class DatafileParameterSet(MyTardisResourceBase): + """Metadata associated with a datafile parameter set in MyTardis.""" + + +class IngestedProject(MyTardisResourceBase): + """Metadata associated with a project that has been ingested into MyTardis.""" + + classification: DataClassification + description: str + identifiers: Optional[list[str]] + institution: list[URI] + locked: bool + name: str + principal_investigator: str + + +class IngestedExperiment(MyTardisResourceBase): + """Metadata associated with an experiment that has been ingested into MyTardis.""" + + classification: int + description: str + identifiers: Optional[list[str]] + institution_name: str + projects: list[IngestedProject] + title: str + + +class IngestedDataset(MyTardisResourceBase): + """Metadata associated with a dataset that has been ingested into MyTardis.""" + + classification: 
DataClassification + created_time: ISODateTime + description: str + directory: Path + experiments: list[URI] + identifiers: list[str] + immutable: bool + instrument: Instrument + modified_time: ISODateTime + parameter_sets: list[DatasetParameterSet] + public_access: bool + + +class IngestedDatafile(MyTardisResourceBase): + """Metadata associated with a datafile that has been ingested into MyTardis.""" + + created_time: Optional[ISODateTime] + dataset: URI + deleted: bool + deleted_time: Optional[ISODateTime] + directory: Path + filename: str + identifiers: Optional[list[str]] + md5sum: MD5Sum + mimetype: str + modification_time: Optional[ISODateTime] + parameter_sets: list[DatafileParameterSet] + public_access: bool + replicas: list[Replica] + size: int + version: int diff --git a/src/overseers/overseer.py b/src/overseers/overseer.py index 4075bc81..e04cdcb4 100644 --- a/src/overseers/overseer.py +++ b/src/overseers/overseer.py @@ -6,13 +6,12 @@ from collections.abc import Generator from typing import Any -from pydantic import ValidationError -from requests.exceptions import HTTPError +from typeguard import check_type from src.mytardis_client.endpoints import URI, MyTardisEndpoint from src.mytardis_client.mt_rest import MyTardisRESTFactory from src.mytardis_client.objects import MyTardisObject, get_type_info -from src.mytardis_client.response_data import MyTardisIntrospection +from src.mytardis_client.response_data import MyTardisIntrospection, MyTardisResource from src.utils.types.singleton import Singleton logger = logging.getLogger(__name__) @@ -149,44 +148,26 @@ def _get_matches_from_mytardis( self, object_type: MyTardisObject, query_params: dict[str, str], - ) -> list[dict[str, Any]]: + ) -> list[MyTardisResource]: """Get objects from MyTardis that match the given query parameters""" endpoint = get_default_endpoint(object_type) try: - response = self.rest_factory.request( - "GET", endpoint=endpoint, params=query_params - ) - except HTTPError as error: - 
logger.warning( - ( - "Failed HTTP request from Overseer.get_objects call\n" - f"object_type = {object_type}\n" - f"query_params = {query_params}" - ), - exc_info=True, - ) - raise error + objects, _ = self.rest_factory.get(endpoint, query_params) except Exception as error: - logger.error( - ( - "Non-HTTP exception in Overseer.get_objects call\n" - f"object_type = {object_type}\n" - f"search_target = {query_params}" - ), - exc_info=True, - ) - raise error + raise RuntimeError( + "Failed to query matching objects from MyTardis in Overseer. " + f"Object type: {object_type}. Query: {query_params}" + ) from error - response_json: dict[str, Any] = response.json() - return list(response_json["objects"]) + return objects def get_matching_objects( self, object_type: MyTardisObject, object_data: dict[str, str], - ) -> list[dict[str, Any]]: + ) -> list[MyTardisResource]: """Retrieve objects from MyTardis with field values matching the ones in "field_values" The function extracts the type-dependent match keys from 'object_data' and uses them to @@ -197,7 +178,7 @@ def get_matching_objects( for match_keys in matchers: if objects := self._get_matches_from_mytardis(object_type, match_keys): - return list(objects) + return objects return [] @@ -220,35 +201,8 @@ def get_uris( A list of object URIs from the search request made. """ objects = self._get_matches_from_mytardis(object_type, match_keys) - if not objects: - return [] - - return_list: list[URI] = [] - - for obj in objects: - try: - uri = URI(obj["resource_uri"]) - except KeyError as error: - logger.error( - ( - "Malformed return from MyTardis. No resource_uri found for " - f"{object_type} searching with {match_keys}. Object in " - f"question is {obj}." - ), - exc_info=True, - ) - raise error - except ValidationError as error: - logger.error( - ( - "Malformed return from MyTardis. 
Unable to conform " - "resource_uri into URI format" - ), - exc_info=True, - ) - raise error - return_list.append(uri) - return return_list + + return [obj.resource_uri for obj in objects] def get_uris_by_identifier( self, object_type: MyTardisObject, identifier: str @@ -265,28 +219,14 @@ def fetch_mytardis_setup(self) -> MyTardisIntrospection: Requests introspection info from MyTardis instance configured in connection """ - response = self.rest_factory.request("GET", "/introspection") + objects, _ = self.rest_factory.get("/introspection") - response_dict = response.json() - if response_dict == {} or response_dict["objects"] == []: - raise ValueError( - ( - "MyTardis introspection did not return any data when called from " - "ConfigFromEnv.get_mytardis_setup" - ) - ) - if len(response_dict["objects"]) > 1: + if len(objects) != 1: raise ValueError( ( - """MyTardis introspection returned more than one object when called from - ConfigFromEnv.get_mytardis_setup\n - Returned response was: %s""", - response_dict, + f"Expected a single object from introspection endpoint, but got {len(objects)}. " + f"MyTardis may be misconfigured. 
Objects returned: {objects}" ) ) - introspection = MyTardisIntrospection.model_validate( - response_dict["objects"][0] - ) - - return introspection + return check_type(objects[0], MyTardisIntrospection) diff --git a/src/profiles/abi_music/parsing.py b/src/profiles/abi_music/parsing.py index a6e84cd2..b8313429 100644 --- a/src/profiles/abi_music/parsing.py +++ b/src/profiles/abi_music/parsing.py @@ -13,14 +13,13 @@ from slugify import slugify from src.blueprints.common_models import GroupACL, UserACL -from src.blueprints.custom_data_types import MTUrl from src.blueprints.datafile import RawDatafile from src.blueprints.dataset import RawDataset from src.blueprints.experiment import RawExperiment from src.blueprints.project import RawProject from src.extraction.manifest import IngestionManifest from src.extraction.metadata_extractor import IMetadataExtractor -from src.mytardis_client.common_types import DataClassification +from src.mytardis_client.common_types import DataClassification, MTUrl from src.profiles.abi_music.abi_music_consts import ( ABI_MUSIC_DATASET_RAW_SCHEMA, ABI_MUSIC_DATASET_ZARR_SCHEMA, diff --git a/src/profiles/ro_crate/ro_crate_parser.py b/src/profiles/ro_crate/ro_crate_parser.py index f3d2720c..9992bc3a 100644 --- a/src/profiles/ro_crate/ro_crate_parser.py +++ b/src/profiles/ro_crate/ro_crate_parser.py @@ -17,7 +17,6 @@ from rocrate.utils import as_list, get_norm_value from src.blueprints.common_models import GroupACL, UserACL -from src.blueprints.custom_data_types import validate_isodatetime, validate_url from src.blueprints.datafile import RawDatafile # pylint: disable=duplicate-code from src.blueprints.dataset import RawDataset # pylint: disable=duplicate-code from src.blueprints.experiment import RawExperiment # pylint: disable=duplicate-code @@ -38,6 +37,7 @@ from src.utils.filesystem import checksums, filters from src.utils.filesystem.filesystem_nodes import DirectoryNode from src.utils.filesystem.filters import PathFilterSet +from 
src.utils.validation import validate_isodatetime, validate_url logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) # set the level for which this logger will be printed. diff --git a/src/smelters/smelter.py b/src/smelters/smelter.py index 39fd22cb..24fb4e19 100644 --- a/src/smelters/smelter.py +++ b/src/smelters/smelter.py @@ -11,12 +11,12 @@ from pydantic import ValidationError from src.blueprints.common_models import Parameter, ParameterSet -from src.blueprints.custom_data_types import MTUrl from src.blueprints.datafile import RawDatafile, RefinedDatafile from src.blueprints.dataset import RawDataset, RefinedDataset from src.blueprints.experiment import RawExperiment, RefinedExperiment from src.blueprints.project import RawProject, RefinedProject from src.config.config import GeneralConfig, SchemaConfig +from src.mytardis_client.common_types import MTUrl from src.mytardis_client.endpoints import URI from src.overseers.overseer import MYTARDIS_PROJECTS_DISABLED_MESSAGE, Overseer diff --git a/src/utils/types/type_helpers.py b/src/utils/types/type_helpers.py new file mode 100644 index 00000000..c1f8574a --- /dev/null +++ b/src/utils/types/type_helpers.py @@ -0,0 +1,16 @@ +"""Helpers for working with types and type-checking.""" + +from typing import Any, TypeGuard, TypeVar + +T = TypeVar("T") + + +def elements_are(elems: list[Any], query_type: type[T]) -> TypeGuard[list[T]]: + """Check if all elements in a list are of a certain type.""" + return all(isinstance(elem, query_type) for elem in elems) + + +def is_list_of(obj: Any, query_type: type[T]) -> TypeGuard[list[T]]: + """Check if an object is a list with elements of a certain type.""" + + return isinstance(obj, list) and all(isinstance(entry, query_type) for entry in obj) diff --git a/src/utils/validation.py b/src/utils/validation.py new file mode 100644 index 00000000..d5e202e7 --- /dev/null +++ b/src/utils/validation.py @@ -0,0 +1,54 @@ +"""Utility functions for validating data.""" + +from typing 
import Any + +from dateutil import parser +from validators import url + + +def is_hex(value: str) -> bool: + """Check if a string is a valid hexadecimal string.""" + try: + _ = int(value, 16) + except ValueError: + return False + + return True + + +def validate_md5sum(value: str) -> str: + """Check that the input string is a well-formed MD5Sum.""" + if len(value) != 32: + raise ValueError("MD5Sum must contain exactly 32 characters") + if not is_hex(value): + raise ValueError("MD5Sum must be a valid hexadecimal string") + + return value + + +def validate_isodatetime(value: Any) -> str: + """Custom validator to ensure that the value is a string object and that it can be + parsed as an ISO 8601 formatted datetime string""" + if not isinstance(value, str): + raise TypeError(f'Unexpected type for ISO date/time stamp: "{type(value)}"') + + try: + _ = parser.isoparse(value) + except ValueError as e: + raise ValueError( + f'Passed string value "{value}" is not an ISO 8601 formatted ' + "date/time string. 
Format should follow " + "YYYY-MM-DDTHH:MM:SS.SSSSSS+HH:MM convention" + ) from e + + return value + + +def validate_url(value: Any) -> str: + """Custom validator for Urls since the default pydantic ones are not compatible + with urllib""" + if not isinstance(value, str): + raise TypeError(f'Unexpected type for URL: "{type(value)}"') + if url(value): + return value + raise ValueError(f'Passed string value "{value}" is not a valid URL') diff --git a/tests/fixtures/fixtures_constants.py b/tests/fixtures/fixtures_constants.py index 1b76b12f..dd4bdefb 100644 --- a/tests/fixtures/fixtures_constants.py +++ b/tests/fixtures/fixtures_constants.py @@ -8,8 +8,9 @@ from pytz import BaseTzInfo from src.blueprints.common_models import GroupACL, Parameter, UserACL -from src.blueprints.custom_data_types import ISODateTime, Username +from src.blueprints.custom_data_types import Username from src.blueprints.storage_boxes import StorageTypesEnum +from src.mytardis_client.common_types import ISODateTime from src.mytardis_client.endpoints import URI diff --git a/tests/fixtures/fixtures_dataclasses.py b/tests/fixtures/fixtures_dataclasses.py index fb9f3d44..56318318 100644 --- a/tests/fixtures/fixtures_dataclasses.py +++ b/tests/fixtures/fixtures_dataclasses.py @@ -9,7 +9,7 @@ from pytest import fixture from src.blueprints.common_models import GroupACL, Parameter, ParameterSet, UserACL -from src.blueprints.custom_data_types import ISODateTime, Username +from src.blueprints.custom_data_types import Username from src.blueprints.datafile import ( Datafile, DatafileReplica, @@ -19,7 +19,7 @@ from src.blueprints.dataset import Dataset, RawDataset, RefinedDataset from src.blueprints.experiment import Experiment, RawExperiment, RefinedExperiment from src.blueprints.project import Project, RawProject, RefinedProject -from src.mytardis_client.common_types import DataClassification +from src.mytardis_client.common_types import DataClassification, ISODateTime from src.mytardis_client.endpoints 
import URI diff --git a/tests/fixtures/fixtures_responses.py b/tests/fixtures/fixtures_responses.py index 8e5a8a69..524f59be 100644 --- a/tests/fixtures/fixtures_responses.py +++ b/tests/fixtures/fixtures_responses.py @@ -165,13 +165,13 @@ def datafile_get_response_paginated_second() -> dict[str, Any]: def get_project_details( project_ids: list[str], project_description: str, - project_institutions: list[str], project_name: str, project_principal_investigator: str, project_url: str, ) -> List[Dict[str, Any]]: return [ { + "classification": 25, "created_by": "api/v1/user/1/", "datafile_count": 2, "dataset_count": 1, @@ -181,7 +181,9 @@ def get_project_details( "experiment_count": 1, "id": 1, "identifiers": project_ids, - "institution": project_institutions, + "institution": [ + "/api/v1/institution/1/", + ], "locked": False, "name": project_name, "parameter_sets": [], diff --git a/tests/test_custom_data_types.py b/tests/test_custom_data_types.py index 9c3cb000..d18d249a 100644 --- a/tests/test_custom_data_types.py +++ b/tests/test_custom_data_types.py @@ -3,25 +3,16 @@ # nosec assert_used # flake8: noqa S101 -from datetime import datetime - import pytest -from dateutil import tz from pydantic import BaseModel -from src.blueprints.custom_data_types import ISODateTime, Username - -NZT = tz.gettz("Pacific/Auckland") +from src.blueprints.custom_data_types import Username class DummyUsernames(BaseModel): user: Username -class DummyISODateTime(BaseModel): - iso_time: ISODateTime - - @pytest.mark.parametrize("upis", ["test001", "ts001", "tst001"]) def test_UPI_is_good(upis: Username) -> None: # pylint: disable=invalid-name test_class = DummyUsernames(user=upis) @@ -52,36 +43,3 @@ def test_malformed_UPI(upis: Username) -> None: # pylint: disable=invalid-name "input_type=str]\n " ) ) in str(e_info.value) - - -@pytest.mark.parametrize( - "iso_strings, expected", - ( - ("2022-01-01T12:00:00", "2022-01-01T12:00:00"), - ("2022-01-01T12:00:00+12:00", "2022-01-01T12:00:00+12:00"), 
- ("2022-01-01T12:00:00.0+12:00", "2022-01-01T12:00:00.0+12:00"), - ("2022-01-01T12:00:00.00+12:00", "2022-01-01T12:00:00.00+12:00"), - ("2022-01-01T12:00:00.000+12:00", "2022-01-01T12:00:00.000+12:00"), - ("2022-01-01T12:00:00.0000+12:00", "2022-01-01T12:00:00.0000+12:00"), - ("2022-01-01T12:00:00.00000+12:00", "2022-01-01T12:00:00.00000+12:00"), - ( - "2022-01-01T12:00:00.000000+12:00", - "2022-01-01T12:00:00.000000+12:00", - ), - ( - datetime(2022, 1, 1, 12, 00, 00, 000000).isoformat(), - "2022-01-01T12:00:00", - ), - (datetime(2022, 1, 1, tzinfo=NZT).isoformat(), "2022-01-01T00:00:00+13:00"), - ( - datetime(2022, 1, 1, 12, 00, 00, tzinfo=NZT).isoformat(), - "2022-01-01T12:00:00+13:00", - ), - ), -) -def test_good_ISO_DateTime_string( # pylint: disable=invalid-name - iso_strings: str, - expected: str, -) -> None: - test_class = DummyISODateTime(iso_time=iso_strings) - assert test_class.iso_time == expected diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 963c9251..fb16c3be 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -5,7 +5,7 @@ from pytest import fixture from src.blueprints.common_models import Parameter, ParameterSet -from src.blueprints.custom_data_types import MTUrl +from src.mytardis_client.common_types import MTUrl from src.mytardis_client.endpoints import URI diff --git a/tests/test_mytardis_client_rest_factory.py b/tests/test_mytardis_client_rest_factory.py index d70eaae1..ed6f5235 100644 --- a/tests/test_mytardis_client_rest_factory.py +++ b/tests/test_mytardis_client_rest_factory.py @@ -23,6 +23,8 @@ MyTardisRESTFactory, sanitize_params, ) +from src.mytardis_client.response_data import IngestedDatafile +from src.utils.types.type_helpers import is_list_of logger = logging.getLogger(__name__) logger.propagate = True @@ -98,12 +100,13 @@ def test_mytardis_client_rest_get_single( json=datafile_get_response_single, ) - datafiles, meta = mt_client.get("/dataset_file", object_type=Datafile) + datafiles, 
meta = mt_client.get("/dataset_file") assert isinstance(datafiles, list) + assert is_list_of(datafiles, IngestedDatafile) assert len(datafiles) == 1 assert datafiles[0].resource_uri == URI("/api/v1/dataset_file/0/") - assert datafiles[0].obj.filename == "test_filename.txt" + assert datafiles[0].filename == "test_filename.txt" assert isinstance(meta, GetResponseMeta) assert meta.total_count == 1 @@ -126,11 +129,10 @@ def test_mytardis_client_rest_get_multi( datafiles, meta = mt_client.get( "/dataset_file", - object_type=Datafile, meta_params=GetRequestMetaParams(offset=2, limit=3), ) - assert isinstance(datafiles, list) + assert is_list_of(datafiles, IngestedDatafile) assert len(datafiles) == 30 assert (isinstance(df, Datafile) for df in datafiles) @@ -138,9 +140,9 @@ def test_mytardis_client_rest_get_multi( assert datafiles[1].resource_uri == URI("/api/v1/dataset_file/1/") assert datafiles[2].resource_uri == URI("/api/v1/dataset_file/2/") - assert datafiles[0].obj.filename == "test_filename_0.txt" - assert datafiles[1].obj.filename == "test_filename_1.txt" - assert datafiles[2].obj.filename == "test_filename_2.txt" + assert datafiles[0].filename == "test_filename_0.txt" + assert datafiles[1].filename == "test_filename_1.txt" + assert datafiles[2].filename == "test_filename_2.txt" assert isinstance(meta, GetResponseMeta) assert meta.total_count == 30 @@ -172,17 +174,16 @@ def test_mytardis_client_rest_get_all( datafiles, total_count = mt_client.get_all( "/dataset_file", - object_type=Datafile, batch_size=20, ) - assert isinstance(datafiles, list) + assert is_list_of(datafiles, IngestedDatafile) assert len(datafiles) == 30 assert (isinstance(df, Datafile) for df in datafiles) for i in range(30): assert datafiles[i].resource_uri == URI(f"/api/v1/dataset_file/{i}/") - assert datafiles[i].obj.filename == f"test_filename_{i}.txt" + assert datafiles[i].filename == f"test_filename_{i}.txt" assert total_count == 30 @@ -296,7 +297,6 @@ def 
test_mytardis_client_get_params_are_sanitized( _ = mt_client.get( "/dataset_file", - object_type=Datafile, query_params={"dataset": URI("/api/v1/dataset/0/")}, meta_params=GetRequestMetaParams(limit=1, offset=0), ) diff --git a/tests/test_overseers.py b/tests/test_overseers.py index bd8e0159..4ea1ff00 100644 --- a/tests/test_overseers.py +++ b/tests/test_overseers.py @@ -18,8 +18,9 @@ from src.mytardis_client.endpoints import URI from src.mytardis_client.mt_rest import MyTardisRESTFactory from src.mytardis_client.objects import MyTardisObject -from src.mytardis_client.response_data import MyTardisIntrospection +from src.mytardis_client.response_data import IngestedProject, MyTardisIntrospection from src.overseers.overseer import Overseer +from src.utils.types.type_helpers import is_list_of logger = logging.getLogger(__name__) logger.propagate = True @@ -67,21 +68,26 @@ def test_get_matches_from_mytardis( ) # pylint: disable=protected-access - assert ( - overseer._get_matches_from_mytardis( - object_type, - {"name": project_name}, - ) - == project_response_dict["objects"] + expected_projects = [ + IngestedProject.model_validate(proj) + for proj in project_response_dict["objects"] + ] + + retrieved_projects = overseer._get_matches_from_mytardis( + object_type, + {"name": project_name}, ) + assert is_list_of(retrieved_projects, IngestedProject) + assert retrieved_projects == expected_projects + # pylint: disable=protected-access assert ( overseer._get_matches_from_mytardis( object_type, {"identifier": project_identifiers[0]}, ) - == project_response_dict["objects"] + == expected_projects ) Overseer.clear() @@ -89,7 +95,6 @@ def test_get_matches_from_mytardis( @responses.activate def test_get_objects_http_error( - caplog: LogCaptureFixture, connection: ConnectionConfig, overseer: Overseer, ) -> None: @@ -116,15 +121,10 @@ def test_get_objects_http_error( ], status=504, ) - caplog.set_level(logging.WARNING) - with pytest.raises(HTTPError): + with 
pytest.raises(RuntimeError): overseer.get_matching_objects(object_type, {"name": search_string}) - assert "Failed HTTP request" in caplog.text - assert "Overseer" in caplog.text - assert f"{object_type}" in caplog.text - Overseer.clear() @@ -142,7 +142,7 @@ def test_get_objects_general_error( f"object_type = {object_type}\n" "query_params" ) - with pytest.raises(IOError): + with pytest.raises(RuntimeError): _ = overseer.get_matching_objects(object_type, {"name": search_string}) assert error_str in caplog.text @@ -257,15 +257,14 @@ def test_get_uris_no_objects( @responses.activate def test_get_uris_malformed_return_dict( - caplog: LogCaptureFixture, connection: ConnectionConfig, overseer: Overseer, project_response_dict: dict[str, Any], ) -> None: - caplog.set_level(logging.ERROR) + test_dict = project_response_dict test_dict["objects"][0].pop("resource_uri") - object_type = MyTardisObject.PROJECT + endpoint = "project" search_string = "Project_1" responses.add( @@ -290,27 +289,21 @@ def test_get_uris_malformed_return_dict( ], status=200, ) - error_str = ( - "Malformed return from MyTardis. 
No resource_uri found for " - f"{object_type} searching with {search_string}" - ) - with pytest.raises(KeyError): + + with pytest.raises(RuntimeError): _ = overseer.get_uris_by_identifier( MyTardisObject.PROJECT, search_string, ) - assert error_str in caplog.text + Overseer.clear() @responses.activate def test_get_uris_ensure_http_errors_caught_by_get_objects( - caplog: LogCaptureFixture, connection: ConnectionConfig, overseer: Overseer, ) -> None: - caplog.set_level(logging.WARNING) - object_type = MyTardisObject.PROJECT endpoint = "project" search_string = "Project_1" responses.add( @@ -334,15 +327,12 @@ def test_get_uris_ensure_http_errors_caught_by_get_objects( status=504, ) - with pytest.raises(HTTPError): + with pytest.raises(RuntimeError): overseer.get_uris_by_identifier( MyTardisObject.PROJECT, search_string, ) - assert "Failed HTTP request from Overseer" in caplog.text - assert f"{object_type}" in caplog.text - Overseer.clear() @@ -354,7 +344,7 @@ def test_get_uris_general_error( mock_mytardis_api_request.side_effect = IOError() object_type = MyTardisObject.PROJECT search_string = "Project_1" - with pytest.raises(IOError): + with pytest.raises(RuntimeError): _ = overseer.get_uris_by_identifier( object_type, search_string, @@ -422,7 +412,6 @@ def test_get_mytardis_setup_general_error( @responses.activate def test_get_mytardis_setup_no_objects( - caplog: LogCaptureFixture, overseer: Overseer, connection: ConnectionConfig, response_dict_not_found: dict[str, Any], @@ -436,14 +425,10 @@ def test_get_mytardis_setup_no_objects( json=(response_dict_not_found), status=200, ) - caplog.set_level(logging.ERROR) - error_str = ( - "MyTardis introspection did not return any data when called from " - "ConfigFromEnv.get_mytardis_setup" - ) - with pytest.raises(ValueError, match=error_str): + + with pytest.raises(ValueError): _ = overseer.fetch_mytardis_setup() - assert error_str in caplog.text + Overseer.clear() diff --git a/tests/test_utils_validation.py 
b/tests/test_utils_validation.py new file mode 100644 index 00000000..7c460338 --- /dev/null +++ b/tests/test_utils_validation.py @@ -0,0 +1,114 @@ +"""Tests for the validators module.""" + +# pylint: disable=missing-function-docstring +# nosec assert_used +# flake8: noqa S101 + +from datetime import datetime + +import pytest +from dateutil import tz + +from src.utils.validation import is_hex, validate_isodatetime, validate_md5sum + +NZT = tz.gettz("Pacific/Auckland") + + +def test_is_hex_valid_hex() -> None: + """Test is_hex function with valid input.""" + assert is_hex("0") + assert is_hex("1") + assert is_hex("a") + assert is_hex("f") + assert is_hex("A") + assert is_hex("F") + assert is_hex("ff") + assert is_hex("FF") + assert is_hex("0f") + + assert is_hex("0x0") + assert is_hex("0X0") + assert is_hex("0x0f") + assert is_hex("0X0F") + + assert is_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF") + assert is_hex("04578945702952854684859406729090019481057396748385407482570582857") + + +def test_is_hex_invalid_hex() -> None: + """Test is_hex function with invalid input.""" + assert not is_hex("g") + assert not is_hex("G") + assert not is_hex("z") + assert not is_hex("Z") + assert not is_hex("0x") + assert not is_hex("0X") + assert not is_hex("0fG") + assert not is_hex("0FG") + assert not is_hex("0x0g") + assert not is_hex("0X0g") + assert not is_hex("abcdefABCDEF0123456789g") + + +@pytest.mark.parametrize( + "value", + [ + "0123456789abcdef0123456789abcdef", + "0123456789ABCDEF0123456789ABCDEF", + "ffffffffffffffffffffffffffffffff", + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", + "00000000000000000000000000000000", + ], +) +def test_md5sum_valid(value: str) -> None: + """Test that a valid MD5Sum is accepted and serialized correctly.""" + + assert validate_md5sum(value) == value + + +@pytest.mark.parametrize( + "value", + [ + "0123456789abcdef0123456789abcdefa", + "0123456789ABCDEF0123456789ABCDEFA", + 
"fffffffffffffffffffffffffffffff", + "fffffffffffffffffffffffffffffffff", + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", + "AAAAA", + "00000", + "A", + "F", + "0", + "&*%^%$%^$%&*&^()", + "banana", + ], +) +def test_md5sum_invalid(value: str) -> None: + """Test that an invalid MD5Sum is rejected.""" + + with pytest.raises(ValueError): + _ = validate_md5sum(value) + + +@pytest.mark.parametrize( + "valid_iso_dt", + [ + "2022-01-01T12:00:00", + "2022-01-01T12:00:00+12:00", + "2022-01-01T12:00:00.0+12:00", + "2022-01-01T12:00:00.00+12:00", + "2022-01-01T12:00:00.000+12:00", + "2022-01-01T12:00:00.0000+12:00", + "2022-01-01T12:00:00.00000+12:00", + "2022-01-01T12:00:00.000000+12:00", + datetime(2022, 1, 1, 12, 00, 00, 000000).isoformat(), + datetime(2022, 1, 1, tzinfo=NZT).isoformat(), + datetime(2022, 1, 1, 12, 00, 00, tzinfo=NZT).isoformat(), + ], +) +def test_good_ISO_DateTime_string( # pylint: disable=invalid-name + valid_iso_dt: str, +) -> None: + assert validate_isodatetime(valid_iso_dt) == valid_iso_dt