diff --git a/changes/445.feature.rst b/changes/445.feature.rst new file mode 100644 index 00000000..05536234 --- /dev/null +++ b/changes/445.feature.rst @@ -0,0 +1 @@ +Start versioning files by allows Node instances to use multiple versions of tags. diff --git a/pyproject.toml b/pyproject.toml index 070d823b..908d9f1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "numpy >=1.24", "astropy >=5.3.0", # "rad >=0.22.0, <0.23.0", - "rad @ git+https://github.com/spacetelescope/rad.git", + "rad @ git+https://github.com/braingram/rad.git@versioned_demo", "asdf-standard >=1.1.0", ] dynamic = ["version"] diff --git a/src/roman_datamodels/maker_utils/_common_meta.py b/src/roman_datamodels/maker_utils/_common_meta.py index 62e80155..ec279a62 100644 --- a/src/roman_datamodels/maker_utils/_common_meta.py +++ b/src/roman_datamodels/maker_utils/_common_meta.py @@ -307,6 +307,7 @@ def mk_l2_cal_step(**kwargs): l2calstep["saturation"] = kwargs.get("saturation", "INCOMPLETE") l2calstep["skymatch"] = kwargs.get("skymatch", "INCOMPLETE") l2calstep["tweakreg"] = kwargs.get("tweakreg", "INCOMPLETE") + l2calstep["two_step"] = kwargs.get("two_step", "INCOMPLETE") return l2calstep diff --git a/src/roman_datamodels/stnode/_converters.py b/src/roman_datamodels/stnode/_converters.py index 43e33582..7144f0e5 100644 --- a/src/roman_datamodels/stnode/_converters.py +++ b/src/roman_datamodels/stnode/_converters.py @@ -5,7 +5,14 @@ from asdf.extension import Converter, ManifestExtension from astropy.time import Time -from ._registry import LIST_NODE_CLASSES_BY_TAG, NODE_CONVERTERS, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG +from ._registry import ( + LIST_NODE_CLASSES_BY_PATTERN, + NODE_CLASSES_BY_TAG, + NODE_CONVERTERS, + OBJECT_NODE_CLASSES_BY_PATTERN, + SCALAR_NODE_CLASSES_BY_PATTERN, +) +from ._stnode import _MANIFESTS __all__ = [ "NODE_EXTENSIONS", @@ -42,11 +49,11 @@ class TaggedObjectNodeConverter(_RomanConverter): @property def tags(self): - return list(OBJECT_NODE_CLASSES_BY_TAG.keys()) + return list(OBJECT_NODE_CLASSES_BY_PATTERN.keys()) @property def types(self): - return list(OBJECT_NODE_CLASSES_BY_TAG.values()) + return list(OBJECT_NODE_CLASSES_BY_PATTERN.values()) def select_tag(self, obj, tags, ctx): return obj.tag @@ -55,7 +62,7 @@ def to_yaml_tree(self, obj, tag, ctx): return dict(obj._data) def from_yaml_tree(self, node, tag, ctx): - return OBJECT_NODE_CLASSES_BY_TAG[tag](node) + return NODE_CLASSES_BY_TAG[tag](node) class TaggedListNodeConverter(_RomanConverter): @@ -65,11 +72,11 @@ class TaggedListNodeConverter(_RomanConverter): @property def tags(self): - return list(LIST_NODE_CLASSES_BY_TAG.keys()) + return list(LIST_NODE_CLASSES_BY_PATTERN.keys()) @property def types(self): - return list(LIST_NODE_CLASSES_BY_TAG.values()) + return list(LIST_NODE_CLASSES_BY_PATTERN.values()) def select_tag(self, obj, tags, ctx): return obj.tag @@ -78,7 +85,7 @@ def to_yaml_tree(self, obj, tag, ctx): return list(obj) def from_yaml_tree(self, node, tag, ctx): - return LIST_NODE_CLASSES_BY_TAG[tag](node) + return NODE_CLASSES_BY_TAG[tag](node) class TaggedScalarNodeConverter(_RomanConverter): @@ -88,37 +95,30 @@ class TaggedScalarNodeConverter(_RomanConverter): @property def tags(self): - return list(SCALAR_NODE_CLASSES_BY_TAG.keys()) + return list(SCALAR_NODE_CLASSES_BY_PATTERN.keys()) @property def types(self): - return list(SCALAR_NODE_CLASSES_BY_TAG.values()) + return list(SCALAR_NODE_CLASSES_BY_PATTERN.values()) def select_tag(self, obj, tags, ctx): return obj.tag def to_yaml_tree(self, obj, tag, ctx): - from ._stnode import FileDate, FpsFileDate, TvacFileDate - node = obj.__class__.__bases__[0](obj) - if tag in (FileDate._tag, FpsFileDate._tag, TvacFileDate._tag): + if "file_date" in tag: converter = ctx.extension_manager.get_converter_for_type(type(node)) node = converter.to_yaml_tree(node, tag, ctx) return node def from_yaml_tree(self, node, tag, ctx): - from ._stnode import FileDate, FpsFileDate, TvacFileDate - - if tag in (FileDate._tag, FpsFileDate._tag, TvacFileDate._tag): + if "file_date" in tag: converter = ctx.extension_manager.get_converter_for_type(Time) node = converter.from_yaml_tree(node, tag, ctx) - - return SCALAR_NODE_CLASSES_BY_TAG[tag](node) + return NODE_CLASSES_BY_TAG[tag](node) # Create the ASDF extension for the STNode classes. -NODE_EXTENSIONS = [ - ManifestExtension.from_uri("asdf://stsci.edu/datamodels/roman/manifests/datamodels-1.0", converters=NODE_CONVERTERS.values()), -] +NODE_EXTENSIONS = [ManifestExtension.from_uri(manifest["id"], converters=NODE_CONVERTERS.values()) for manifest in _MANIFESTS] diff --git a/src/roman_datamodels/stnode/_factories.py b/src/roman_datamodels/stnode/_factories.py index 8b4ade51..5c54f915 100644 --- a/src/roman_datamodels/stnode/_factories.py +++ b/src/roman_datamodels/stnode/_factories.py @@ -3,59 +3,23 @@ These are used to dynamically create classes from the RAD manifest. """ -import importlib.resources - -import yaml from astropy.time import Time -from rad import resources from . import _mixins from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode, name_from_tag_uri __all__ = ["stnode_factory"] -# Map of scalar types in the schemas to the python types -SCALAR_TYPE_MAP = { - "string": str, - "http://stsci.edu/schemas/asdf/time/time-1.1.0": Time, +# Map of scalar types by pattern (str is default) +_SCALAR_TYPE_BY_PATTERN = { + "asdf://stsci.edu/datamodels/roman/tags/file_date-*": Time, + "asdf://stsci.edu/datamodels/roman/tags/fps/file_date-*": Time, + "asdf://stsci.edu/datamodels/roman/tags/tvac/file_date-*": Time, +} +# Map of node types by pattern (TaggedObjectNode is default) +_NODE_TYPE_BY_PATTERN = { + "asdf://stsci.edu/datamodels/roman/tags/cal_logs-*": TaggedListNode, } - -BASE_SCHEMA_PATH = importlib.resources.files(resources) / "schemas" - - -def load_schema_from_uri(schema_uri): - """ - Load the actual schema from the rad resources directly (outside ASDF) - Outside ASDF because this has to occur before the ASDF extensions are - registered. - - Parameters - ---------- - schema_uri : str - The schema_uri found in the RAD manifest - - Returns - ------- - yaml library dictionary from the schema - """ - filename = f"{schema_uri.split('/')[-1]}.yaml" - - if "reference_files" in schema_uri: - schema_path = BASE_SCHEMA_PATH / "reference_files" / filename - elif "/fps/tagged_scalars" in schema_uri: - schema_path = BASE_SCHEMA_PATH / "fps/tagged_scalars" / filename - elif "/fps/" in schema_uri: - schema_path = BASE_SCHEMA_PATH / "fps" / filename - elif "/tvac/tagged_scalars" in schema_uri: - schema_path = BASE_SCHEMA_PATH / "tvac/tagged_scalars" / filename - elif "/tvac/" in schema_uri: - schema_path = BASE_SCHEMA_PATH / "tvac" / filename - elif "tagged_scalars" in schema_uri: - schema_path = BASE_SCHEMA_PATH / "tagged_scalars" / filename - else: - schema_path = BASE_SCHEMA_PATH / filename - - return yaml.safe_load(schema_path.read_bytes()) def class_name_from_tag_uri(tag_uri): @@ -79,94 +43,83 @@ def class_name_from_tag_uri(tag_uri): return class_name -def docstring_from_tag(tag): +def docstring_from_tag(tag_def): """ Read the docstring (if it exists) from the RAD manifest and generate a docstring for the dynamically generated class. Parameters ---------- - tag: dict + tag_def: dict A tag entry from the RAD manifest Returns ------- A docstring for the class based on the tag """ - docstring = f"{tag['description']}\n\n" if "description" in tag else "" + docstring = f"{tag_def['description']}\n\n" if "description" in tag_def else "" - return docstring + f"Class generated from tag '{tag['tag_uri']}'" + return docstring + f"Class generated from tag '{tag_def['tag_uri']}'" -def scalar_factory(tag): +def scalar_factory(pattern, tag_def): """ Factory to create a TaggedScalarNode class from a tag Parameters ---------- - tag: dict + pattern: str + A tag pattern/wildcard + + tag_def: dict A tag entry from the RAD manifest Returns ------- A dynamically generated TaggedScalarNode subclass """ - class_name = class_name_from_tag_uri(tag["tag_uri"]) - schema = load_schema_from_uri(tag["schema_uri"]) + class_name = class_name_from_tag_uri(pattern) # TaggedScalarNode subclasses are really subclasses of the type of the scalar, # with the TaggedScalarNode as a mixin. This is because the TaggedScalarNode # is supposed to be the scalar, but it needs to be serializable under a specific # ASDF tag. - # SCALAR_TYPE_MAP will need to be updated as new wrappers of scalar types are added + # _SCALAR_TYPE_BY_PATTERN will need to be updated as new wrappers of scalar types are added # to the RAD manifest. - if "type" in schema: - type_ = schema["type"] - elif "allOf" in schema: - type_ = schema["allOf"][0]["$ref"] - else: - raise RuntimeError(f"Unknown schema type: {schema}") + # assume everything is a string if not otherwise defined + type_ = _SCALAR_TYPE_BY_PATTERN.get(pattern, str) return type( class_name, - (SCALAR_TYPE_MAP[type_], TaggedScalarNode), - {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring_from_tag(tag)}, + (type_, TaggedScalarNode), + { + "_pattern": pattern, + "_default_tag": tag_def["tag_uri"], + "__module__": "roman_datamodels.stnode", + "__doc__": docstring_from_tag(tag_def), + }, ) -def node_factory(tag): +def node_factory(pattern, tag_def): """ Factory to create a TaggedObjectNode or TaggedListNode class from a tag Parameters ---------- - tag: dict + pattern: str + A tag pattern/wildcard + + tag_def: dict A tag entry from the RAD manifest Returns ------- A dynamically generated TaggedObjectNode or TaggedListNode subclass """ - class_name = class_name_from_tag_uri(tag["tag_uri"]) - schema = load_schema_from_uri(tag["schema_uri"]) - - if "type" in schema: - # Determine if the class is a TaggedObjectNode or TaggedListNode based on the - # type defined in the schema: - # - TaggedObjectNode if type is "object" - # - TaggedListNode if type is "array" (array in jsonschema represents Python list) - if schema["type"] == "object": - class_type = TaggedObjectNode - elif schema["type"] == "array": - class_type = TaggedListNode - else: - raise RuntimeError(f"Unknown schema type: {schema['type']}") - # Use of allOf in the schema indicates that the class is a TaggedObjectNode - # which is "extending" another class. - elif "allOf" in schema: - class_type = TaggedObjectNode - else: - raise RuntimeError(f"Unknown schema type for: {tag['schema_uri']}") + class_name = class_name_from_tag_uri(pattern) + + class_type = _NODE_TYPE_BY_PATTERN.get(pattern, TaggedObjectNode) # In special cases one may need to add additional features to a tagged node class. # This is done by creating a mixin class with the name Mixin in _mixins.py @@ -179,17 +132,25 @@ def node_factory(tag): return type( class_name, class_type, - {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring_from_tag(tag)}, + { + "_pattern": pattern, + "_default_tag": tag_def["tag_uri"], + "__module__": "roman_datamodels.stnode", + "__doc__": docstring_from_tag(tag_def), + }, ) -def stnode_factory(tag): +def stnode_factory(pattern, tag_def): """ Construct a tagged STNode class from a tag Parameters ---------- - tag: dict + pattern: str + A tag pattern/wildcard + + tag_def: dict A tag entry from the RAD manifest Returns @@ -198,7 +159,7 @@ def stnode_factory(tag): """ # TaggedScalarNodes are a special case because they are not a subclass of a # _node class, but rather a subclass of the type of the scalar. - if "tagged_scalar" in tag["schema_uri"]: - return scalar_factory(tag) + if "tagged_scalar" in tag_def["schema_uri"]: + return scalar_factory(pattern, tag_def) else: - return node_factory(tag) + return node_factory(pattern, tag_def) diff --git a/src/roman_datamodels/stnode/_node.py b/src/roman_datamodels/stnode/_node.py index 669317d2..6b43ffd2 100644 --- a/src/roman_datamodels/stnode/_node.py +++ b/src/roman_datamodels/stnode/_node.py @@ -100,7 +100,7 @@ class DNode(MutableMapping): Base class describing all "object" (dict-like) data nodes for STNode classes. """ - _tag = None + _pattern = None _ctx = None def __init__(self, node=None, parent=None, name=None): @@ -311,7 +311,7 @@ class LNode(UserList): Base class describing all "array" (list-like) data nodes for STNode classes. """ - _tag = None + _pattern = None def __init__(self, node=None): if node is None: diff --git a/src/roman_datamodels/stnode/_registry.py b/src/roman_datamodels/stnode/_registry.py index 13562e2f..92105fca 100644 --- a/src/roman_datamodels/stnode/_registry.py +++ b/src/roman_datamodels/stnode/_registry.py @@ -4,8 +4,9 @@ whenever they generated. """ -OBJECT_NODE_CLASSES_BY_TAG = {} -LIST_NODE_CLASSES_BY_TAG = {} -SCALAR_NODE_CLASSES_BY_TAG = {} +OBJECT_NODE_CLASSES_BY_PATTERN = {} +LIST_NODE_CLASSES_BY_PATTERN = {} +SCALAR_NODE_CLASSES_BY_PATTERN = {} SCALAR_NODE_CLASSES_BY_KEY = {} NODE_CONVERTERS = {} +NODE_CLASSES_BY_TAG = {} diff --git a/src/roman_datamodels/stnode/_stnode.py b/src/roman_datamodels/stnode/_stnode.py index f829a60c..28836df1 100644 --- a/src/roman_datamodels/stnode/_stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -12,7 +12,12 @@ from rad import resources from ._factories import stnode_factory -from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG +from ._registry import ( + LIST_NODE_CLASSES_BY_PATTERN, + NODE_CLASSES_BY_TAG, + OBJECT_NODE_CLASSES_BY_PATTERN, + SCALAR_NODE_CLASSES_BY_PATTERN, +) __all__ = [ "NODE_CLASSES", @@ -22,32 +27,42 @@ # Load the manifest directly from the rad resources and not from ASDF. # This is because the ASDF extensions have to be created before they can be registered # and this module creates the classes used by the ASDF extension. -DATAMODELS_MANIFEST_PATH = importlib.resources.files(resources) / "manifests" / "datamodels-1.0.yaml" -DATAMODELS_MANIFEST = yaml.safe_load(DATAMODELS_MANIFEST_PATH.read_bytes()) +_MANIFEST_DIR = importlib.resources.files(resources) / "manifests" +# sort manifests by version (newest first) +_MANIFEST_PATHS = sorted([path for path in _MANIFEST_DIR.glob("*.yaml")], reverse=True) +_MANIFESTS = [yaml.safe_load(path.read_bytes()) for path in _MANIFEST_PATHS] -def _factory(tag): +def _factory(pattern, tag_def): """ Wrap the __all__ append and class creation in a function to avoid the linter getting upset """ - cls = stnode_factory(tag) + cls = stnode_factory(pattern, tag_def) class_name = cls.__name__ globals()[class_name] = cls # Add to namespace of module __all__.append(class_name) # add to __all__ so it's imported with `from . import *` + return cls # Main dynamic class creation loop # Reads each tag entry from the manifest and creates a class for it -for tag in DATAMODELS_MANIFEST["tags"]: - _factory(tag) +_generated = {} +for manifest in _MANIFESTS: + for tag_def in manifest["tags"]: + # make pattern from tag + base, _ = tag_def["tag_uri"].rsplit("-", maxsplit=1) + pattern = f"{base}-*" + if pattern not in _generated: + _generated[pattern] = _factory(pattern, tag_def) + NODE_CLASSES_BY_TAG[tag_def["tag_uri"]] = _generated[pattern] # List of node classes made available by this library. # This is part of the public API. NODE_CLASSES = ( - list(OBJECT_NODE_CLASSES_BY_TAG.values()) - + list(LIST_NODE_CLASSES_BY_TAG.values()) - + list(SCALAR_NODE_CLASSES_BY_TAG.values()) + list(OBJECT_NODE_CLASSES_BY_PATTERN.values()) + + list(LIST_NODE_CLASSES_BY_PATTERN.values()) + + list(SCALAR_NODE_CLASSES_BY_PATTERN.values()) ) diff --git a/src/roman_datamodels/stnode/_tagged.py b/src/roman_datamodels/stnode/_tagged.py index 80fb0f4e..a17397f8 100644 --- a/src/roman_datamodels/stnode/_tagged.py +++ b/src/roman_datamodels/stnode/_tagged.py @@ -10,10 +10,10 @@ from ._node import DNode, LNode from ._registry import ( - LIST_NODE_CLASSES_BY_TAG, - OBJECT_NODE_CLASSES_BY_TAG, + LIST_NODE_CLASSES_BY_PATTERN, + OBJECT_NODE_CLASSES_BY_PATTERN, SCALAR_NODE_CLASSES_BY_KEY, - SCALAR_NODE_CLASSES_BY_TAG, + SCALAR_NODE_CLASSES_BY_PATTERN, ) __all__ = [ @@ -65,14 +65,19 @@ class TaggedObjectNode(DNode): def __init_subclass__(cls, **kwargs) -> None: """ - Register any subclasses of this class in the OBJECT_NODE_CLASSES_BY_TAG + Register any subclasses of this class in the OBJECT_NODE_CLASSES_BY_PATTERN registry. """ super().__init_subclass__(**kwargs) if cls.__name__ != "TaggedObjectNode": - if cls._tag in OBJECT_NODE_CLASSES_BY_TAG: - raise RuntimeError(f"TaggedObjectNode class for tag '{cls._tag}' has been defined twice") - OBJECT_NODE_CLASSES_BY_TAG[cls._tag] = cls + if cls._pattern in OBJECT_NODE_CLASSES_BY_PATTERN: + raise RuntimeError(f"TaggedObjectNode class for tag '{cls._pattern}' has been defined twice") + OBJECT_NODE_CLASSES_BY_PATTERN[cls._pattern] = cls + + @property + def _tag(self): + # _tag is required by asdf to allow __asdf_traverse__ + return getattr(self, "_read_tag", self._default_tag) @property def tag(self): @@ -85,7 +90,7 @@ def _schema(self): def get_schema(self): """Retrieve the schema associated with this tag""" - return get_schema_from_tag(self.ctx, self._tag) + return get_schema_from_tag(self.ctx, self.tag) class TaggedListNode(LNode): @@ -97,14 +102,19 @@ class TaggedListNode(LNode): def __init_subclass__(cls, **kwargs) -> None: """ - Register any subclasses of this class in the LIST_NODE_CLASSES_BY_TAG + Register any subclasses of this class in the LIST_NODE_CLASSES_BY_PATTERN registry. """ super().__init_subclass__(**kwargs) if cls.__name__ != "TaggedListNode": - if cls._tag in LIST_NODE_CLASSES_BY_TAG: - raise RuntimeError(f"TaggedListNode class for tag '{cls._tag}' has been defined twice") - LIST_NODE_CLASSES_BY_TAG[cls._tag] = cls + if cls._pattern in LIST_NODE_CLASSES_BY_PATTERN: + raise RuntimeError(f"TaggedListNode class for tag '{cls._pattern}' has been defined twice") + LIST_NODE_CLASSES_BY_PATTERN[cls._pattern] = cls + + @property + def _tag(self): + # _tag is required by asdf to allow __asdf_traverse__ + return getattr(self, "_read_tag", self._default_tag) @property def tag(self): @@ -119,20 +129,20 @@ class TaggedScalarNode: These will all be in the tagged_scalars directory. """ - _tag = None + _pattern = None _ctx = None def __init_subclass__(cls, **kwargs) -> None: """ - Register any subclasses of this class in the SCALAR_NODE_CLASSES_BY_TAG + Register any subclasses of this class in the SCALAR_NODE_CLASSES_BY_PATTERN and SCALAR_NODE_CLASSES_BY_KEY registry. """ super().__init_subclass__(**kwargs) if cls.__name__ != "TaggedScalarNode": - if cls._tag in SCALAR_NODE_CLASSES_BY_TAG: - raise RuntimeError(f"TaggedScalarNode class for tag '{cls._tag}' has been defined twice") - SCALAR_NODE_CLASSES_BY_TAG[cls._tag] = cls - SCALAR_NODE_CLASSES_BY_KEY[name_from_tag_uri(cls._tag)] = cls + if cls._pattern in SCALAR_NODE_CLASSES_BY_PATTERN: + raise RuntimeError(f"TaggedScalarNode class for tag '{cls._pattern}' has been defined twice") + SCALAR_NODE_CLASSES_BY_PATTERN[cls._pattern] = cls + SCALAR_NODE_CLASSES_BY_KEY[name_from_tag_uri(cls._pattern)] = cls @property def ctx(self): @@ -143,16 +153,21 @@ def ctx(self): def __asdf_traverse__(self): return self + @property + def _tag(self): + # _tag is required by asdf to allow __asdf_traverse__ + return getattr(self, "_read_tag", self._default_tag) + @property def tag(self): return self._tag @property def key(self): - return name_from_tag_uri(self._tag) + return name_from_tag_uri(self.tag) def get_schema(self): - return get_schema_from_tag(self.ctx, self._tag) + return get_schema_from_tag(self.ctx, self.tag) def copy(self): return copy.copy(self) diff --git a/tests/conftest.py b/tests/conftest.py index 430882ab..eeeb4344 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,13 @@ import os -import asdf import pytest -import yaml -MANIFEST = yaml.safe_load(asdf.get_config().resource_manager["asdf://stsci.edu/datamodels/roman/manifests/datamodels-1.0"]) +from roman_datamodels.stnode._stnode import _MANIFESTS as MANIFESTS -@pytest.fixture(scope="session") -def manifest(): - return MANIFEST +@pytest.fixture(scope="session", params=MANIFESTS) +def manifest(request): + return request.param @pytest.fixture(scope="function") diff --git a/tests/test_models.py b/tests/test_models.py index 14a5e601..b2b7e7d4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -14,7 +14,7 @@ from roman_datamodels import maker_utils as utils from roman_datamodels.testing import assert_node_equal -from .conftest import MANIFEST +from .conftest import MANIFESTS EXPECTED_COMMON_REFERENCE = {"$ref": "ref_common-1.0.0"} @@ -34,12 +34,13 @@ def datamodel_names(): names = [] extension_manager = asdf.AsdfFile().extension_manager - for tag in MANIFEST["tags"]: - schema_uri = extension_manager.get_tag_definition(tag["tag_uri"]).schema_uris[0] - schema = asdf.schema.load_schema(schema_uri, resolve_references=True) + for manifest in MANIFESTS: + for tag in manifest["tags"]: + schema_uri = extension_manager.get_tag_definition(tag["tag_uri"]).schema_uris[0] + schema = asdf.schema.load_schema(schema_uri, resolve_references=True) - if "datamodel_name" in schema: - names.append(schema["datamodel_name"]) + if "datamodel_name" in schema: + names.append(schema["datamodel_name"]) return names diff --git a/tests/test_stnode.py b/tests/test_stnode.py index a9364774..9b381379 100644 --- a/tests/test_stnode.py +++ b/tests/test_stnode.py @@ -9,18 +9,27 @@ from roman_datamodels.maker_utils._base import NOFN, NONUM, NOSTR from roman_datamodels.testing import assert_node_equal, assert_node_is_copy, wraps_hashable -from .conftest import MANIFEST +from .conftest import MANIFESTS -@pytest.mark.parametrize("tag", MANIFEST["tags"]) -def test_generated_node_classes(tag): - class_name = stnode._factories.class_name_from_tag_uri(tag["tag_uri"]) +@pytest.mark.parametrize("tag_def", [tag_def for manifest in MANIFESTS for tag_def in manifest["tags"]]) +def test_tag_has_node_class(tag_def): + class_name = stnode._factories.class_name_from_tag_uri(tag_def["tag_uri"]) node_class = getattr(stnode, class_name) + assert asdf.util.uri_match(node_class._pattern, tag_def["tag_uri"]) + if node_class._default_tag == tag_def["tag_uri"]: + assert tag_def["description"] in node_class.__doc__ + assert tag_def["tag_uri"] in node_class.__doc__ + else: + default_tag_version = node_class._default_tag.rsplit("-", maxsplit=1)[1] + tag_def_version = tag_def["tag_uri"].rsplit("-", maxsplit=1)[1] + assert asdf.versioning.Version(default_tag_version) > asdf.versioning.Version(tag_def_version) + + +@pytest.mark.parametrize("node_class", stnode.NODE_CLASSES) +def test_node_classes_available_via_stnode(node_class): assert issubclass(node_class, stnode.TaggedObjectNode | stnode.TaggedListNode | stnode.TaggedScalarNode) - assert node_class._tag == tag["tag_uri"] - assert tag["description"] in node_class.__doc__ - assert tag["tag_uri"] in node_class.__doc__ assert node_class.__module__ == stnode.__name__ assert hasattr(stnode, node_class.__name__)