From f731c1b949036d07c0cd2d0b21122241bc61093f Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Tue, 26 Sep 2023 13:35:11 +0200
Subject: [PATCH] WIP update format updating

Track the format version a resource description was originally written in
and use it to convert legacy content during validation.

* base_nodes: `_update_context` stores the `format_version` found in the
  raw data as an integer 3-tuple under `context["original_format"]`, unless
  that key is already set. Versions that are not three decimal parts
  (e.g. "0.2.x") are ignored instead of raising from the context hook.
* validation_context: add the optional `original_format` key to
  `InternalValidationContext`; `get_internal_validation_context` copies it
  (like `collection_base_content`) from the given context when present.
* generic/v0_2:
  - person names may no longer contain "/" or "\"; if the original format
    is older than 0.2.3 these characters are removed before validation.
  - `CiteEntry.doi` accepts a DOI given with a doi.org / dx.doi.org URL
    prefix (http or https) and strips that prefix.
  - `authors` entries given as plain strings are converted to
    `{"name": <string>}`. A bare string in place of the whole `authors`
    list is left untouched (a `str` is itself a `Sequence`, so it would
    otherwise be split into one author per character).
---
 bioimageio/spec/_internal/base_nodes.py       |  6 ++-
 .../spec/_internal/validation_context.py      |  8 +++-
 bioimageio/spec/generic/v0_2.py               | 42 ++++++++++++++++---
 3 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/bioimageio/spec/_internal/base_nodes.py b/bioimageio/spec/_internal/base_nodes.py
index f293fe7f4..b0e9d6893 100644
--- a/bioimageio/spec/_internal/base_nodes.py
+++ b/bioimageio/spec/_internal/base_nodes.py
@@ -201,7 +201,11 @@ def __pydantic_init_subclass__(cls, **kwargs: Any):
 
     @classmethod
     def _update_context(cls, context: InternalValidationContext, data: RdfContent) -> None:
-        pass
+        # set original format if possible (ignore malformed versions, e.g. "0.2.x", instead of raising here)
+        original_format = data.get("format_version")
+        parts = original_format.split(".") if isinstance(original_format, str) else ()
+        if "original_format" not in context and len(parts) == 3 and all(p.isdecimal() for p in parts):
+            context["original_format"] = cast(Tuple[int, int, int], tuple(map(int, parts)))
 
     @classmethod
     def model_validate(
diff --git a/bioimageio/spec/_internal/validation_context.py b/bioimageio/spec/_internal/validation_context.py
index cf7c4012a..333423291 100644
--- a/bioimageio/spec/_internal/validation_context.py
+++ b/bioimageio/spec/_internal/validation_context.py
@@ -32,6 +32,9 @@ class InternalValidationContext(TypedDict):
     warning_level: WarningLevel
     """raise warnings of severity s as validation errors if s >= `warning_level`"""
 
+    original_format: NotRequired[Tuple[int, int, int]]
+    """original format version of the validation data (set dynamically during validation of resource descriptions)."""
+
     collection_base_content: NotRequired[Dict[str, Any]]
     """Collection base content (set dynamically during validation of collection resource descriptions)."""
 
@@ -52,7 +55,8 @@ def get_internal_validation_context(
         file_name=file_name or given_context.get("file_name", "rdf.bioimageio.yaml"),
         warning_level=warning_level or given_context.get(WARNING_LEVEL_CONTEXT_KEY, ERROR),
     )
-    if "collection_base_content" in given_context:
-        ret["collection_base_content"] = given_context["collection_base_content"]
+    for k in {"original_format", "collection_base_content"}:  # TypedDict.__optional_keys__ requires py>=3.9
+        if k in given_context:
+            ret[k] = given_context[k]
 
     return ret
diff --git a/bioimageio/spec/generic/v0_2.py b/bioimageio/spec/generic/v0_2.py
index 3f21a3646..0b7b3f494 100644
--- a/bioimageio/spec/generic/v0_2.py
+++ b/bioimageio/spec/generic/v0_2.py
@@ -1,7 +1,7 @@
 from collections.abc import Mapping, Sequence
-from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, TypeVar, Union
+from typing import Any, List, Literal, Optional, Tuple, TypeVar, Union
 
-from annotated_types import Len, LowerCase, MaxLen, MinLen
+from annotated_types import Len, LowerCase, MaxLen, MinLen, Predicate
 from pydantic import EmailStr, Field, FieldValidationInfo, HttpUrl, field_validator
 from typing_extensions import Annotated
 
@@ -19,7 +19,7 @@
     RdfContent,
     Version,
 )
-from bioimageio.spec._internal.validation_context import InternalValidationContext
+from bioimageio.spec._internal.validation_context import InternalValidationContext, get_internal_validation_context
 from bioimageio.spec.generic.v0_2_converter import convert_from_older_format
 
 KNOWN_SPECIFIC_RESOURCE_TYPES = ("application", "collection", "dataset", "model", "notebook")
@@ -41,8 +41,18 @@ class Attachments(Node, frozen=True):
 
 
 class _Person(Node, frozen=True):
-    name: Optional[str]
+    name: Optional[Annotated[str, Predicate(lambda s: "/" not in s and "\\" not in s)]]
     """Full name"""
+
+    @field_validator("name", mode="before")
+    @classmethod
+    def convert_name(cls, name: Any, info: FieldValidationInfo):
+        ctxt = get_internal_validation_context(info.context)
+        if "original_format" in ctxt and ctxt["original_format"] < (0, 2, 3) and isinstance(name, str):
+            name = name.replace("/", "").replace("\\", "")
+
+        return name
+
     affiliation: Optional[str] = None
     """Affiliation"""
     email: Optional[EmailStr] = None
@@ -58,12 +68,12 @@ class _Person(Node, frozen=True):
 
 
 class Author(_Person, frozen=True):
-    name: str
+    name: Annotated[str, Predicate(lambda s: "/" not in s and "\\" not in s)]
     github_user: Optional[str] = None
 
 
 class Maintainer(_Person, frozen=True):
-    name: Optional[str] = None
+    name: Optional[Annotated[str, Predicate(lambda s: "/" not in s and "\\" not in s)]] = None
     github_user: str
 
 
@@ -97,6 +107,17 @@ class CiteEntry(Node, frozen=True):
     """A digital object identifier (DOI) is the prefered citation reference.
     See https://www.doi.org/ for details. (alternatively specify `url`)"""
 
+    @field_validator("doi", mode="before")
+    @classmethod
+    def accept_prefixed_doi(cls, doi: Any) -> Any:
+        if isinstance(doi, str):
+            for doi_prefix in ("https://doi.org/", "http://doi.org/", "https://dx.doi.org/", "http://dx.doi.org/"):
+                if doi.startswith(doi_prefix):
+                    doi = doi[len(doi_prefix) :]
+                    break
+
+        return doi
+
     url: Optional[str] = None
     """URL to cite (preferably specify a `doi` instead)"""
 
@@ -165,6 +186,15 @@ class GenericBaseNoSource(ResourceDescriptionBase, frozen=True):
     authors: Annotated[Tuple[Author, ...], warn(MinLen(1), "No author specified.")] = ()
     """The authors are the creators of the RDF and the primary points of contact."""
 
+    @field_validator("authors", mode="before")
+    @classmethod
+    def accept_author_strings(cls, authors: Union[Any, Sequence[Any]]) -> Any:
+        """we unofficially accept strings as author entries"""
+        if isinstance(authors, Sequence) and not isinstance(authors, str):
+            authors = [{"name": a} if isinstance(a, str) else a for a in authors]
+
+        return authors
+
     attachments: Optional[Attachments] = None
     """file and other attachments"""