Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RCAL-977 - Version datamodels #445

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions changes/445.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Start versioning files by allowing Node instances to use multiple versions of tags.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ dependencies = [
"gwcs >=0.19.0",
"numpy >=1.24",
"astropy >=5.3.0",
"rad >=0.23.0, <0.24.0",
# "rad >=0.23.0, <0.24.0",
# "rad @ git+https://github.com/spacetelescope/rad.git",
"rad @ git+https://github.com/braingram/rad.git@versioned",
"asdf-standard >=1.1.0",
]
dynamic = ["version"]
Expand Down
59 changes: 26 additions & 33 deletions src/roman_datamodels/stnode/_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
from asdf.extension import Converter, ManifestExtension
from astropy.time import Time

from ._registry import LIST_NODE_CLASSES_BY_TAG, NODE_CONVERTERS, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG
from ._registry import (
LIST_NODE_CLASSES_BY_PATTERN,
NODE_CLASSES_BY_TAG,
NODE_CONVERTERS,
OBJECT_NODE_CLASSES_BY_PATTERN,
SCALAR_NODE_CLASSES_BY_PATTERN,
)
from ._stnode import _MANIFESTS

__all__ = [
"NODE_EXTENSIONS",
Expand Down Expand Up @@ -34,6 +41,14 @@ def __init_subclass__(cls, **kwargs) -> None:

NODE_CONVERTERS[cls.__name__] = cls()

def select_tag(self, obj, tags, ctx):
return obj.tag

def from_yaml_tree(self, node, tag, ctx):
obj = NODE_CLASSES_BY_TAG[tag](node)
obj._read_tag = tag
return obj


class TaggedObjectNodeConverter(_RomanConverter):
"""
Expand All @@ -42,21 +57,15 @@ class TaggedObjectNodeConverter(_RomanConverter):

@property
def tags(self):
return list(OBJECT_NODE_CLASSES_BY_TAG.keys())
return list(OBJECT_NODE_CLASSES_BY_PATTERN.keys())

@property
def types(self):
return list(OBJECT_NODE_CLASSES_BY_TAG.values())

def select_tag(self, obj, tags, ctx):
return obj.tag
return list(OBJECT_NODE_CLASSES_BY_PATTERN.values())

def to_yaml_tree(self, obj, tag, ctx):
return dict(obj._data)

def from_yaml_tree(self, node, tag, ctx):
return OBJECT_NODE_CLASSES_BY_TAG[tag](node)


class TaggedListNodeConverter(_RomanConverter):
"""
Expand All @@ -65,21 +74,15 @@ class TaggedListNodeConverter(_RomanConverter):

@property
def tags(self):
return list(LIST_NODE_CLASSES_BY_TAG.keys())
return list(LIST_NODE_CLASSES_BY_PATTERN.keys())

@property
def types(self):
return list(LIST_NODE_CLASSES_BY_TAG.values())

def select_tag(self, obj, tags, ctx):
return obj.tag
return list(LIST_NODE_CLASSES_BY_PATTERN.values())

def to_yaml_tree(self, obj, tag, ctx):
return list(obj)

def from_yaml_tree(self, node, tag, ctx):
return LIST_NODE_CLASSES_BY_TAG[tag](node)


class TaggedScalarNodeConverter(_RomanConverter):
"""
Expand All @@ -88,37 +91,27 @@ class TaggedScalarNodeConverter(_RomanConverter):

@property
def tags(self):
return list(SCALAR_NODE_CLASSES_BY_TAG.keys())
return list(SCALAR_NODE_CLASSES_BY_PATTERN.keys())

@property
def types(self):
return list(SCALAR_NODE_CLASSES_BY_TAG.values())

def select_tag(self, obj, tags, ctx):
return obj.tag
return list(SCALAR_NODE_CLASSES_BY_PATTERN.values())

def to_yaml_tree(self, obj, tag, ctx):
from ._stnode import FileDate, FpsFileDate, TvacFileDate

node = obj.__class__.__bases__[0](obj)

if tag in (FileDate._tag, FpsFileDate._tag, TvacFileDate._tag):
if "file_date" in tag:
converter = ctx.extension_manager.get_converter_for_type(type(node))
node = converter.to_yaml_tree(node, tag, ctx)

return node

def from_yaml_tree(self, node, tag, ctx):
from ._stnode import FileDate, FpsFileDate, TvacFileDate

if tag in (FileDate._tag, FpsFileDate._tag, TvacFileDate._tag):
if "file_date" in tag:
converter = ctx.extension_manager.get_converter_for_type(Time)
node = converter.from_yaml_tree(node, tag, ctx)

return SCALAR_NODE_CLASSES_BY_TAG[tag](node)
return super().from_yaml_tree(node, tag, ctx)


# Create the ASDF extension for the STNode classes.
NODE_EXTENSIONS = [
ManifestExtension.from_uri("asdf://stsci.edu/datamodels/roman/manifests/datamodels-1.0", converters=NODE_CONVERTERS.values()),
]
NODE_EXTENSIONS = [ManifestExtension.from_uri(manifest["id"], converters=NODE_CONVERTERS.values()) for manifest in _MANIFESTS]
141 changes: 51 additions & 90 deletions src/roman_datamodels/stnode/_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,59 +3,23 @@
These are used to dynamically create classes from the RAD manifest.
"""

import importlib.resources

import yaml
from astropy.time import Time
from rad import resources

from . import _mixins
from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode, name_from_tag_uri

__all__ = ["stnode_factory"]

# Map of scalar types in the schemas to the python types
SCALAR_TYPE_MAP = {
"string": str,
"http://stsci.edu/schemas/asdf/time/time-1.1.0": Time,
# Map of scalar types by pattern (str is default)
_SCALAR_TYPE_BY_PATTERN = {
"asdf://stsci.edu/datamodels/roman/tags/file_date-*": Time,
"asdf://stsci.edu/datamodels/roman/tags/fps/file_date-*": Time,
"asdf://stsci.edu/datamodels/roman/tags/tvac/file_date-*": Time,
}
# Map of node types by pattern (TaggedObjectNode is default)
_NODE_TYPE_BY_PATTERN = {
"asdf://stsci.edu/datamodels/roman/tags/cal_logs-*": TaggedListNode,
}

BASE_SCHEMA_PATH = importlib.resources.files(resources) / "schemas"


def load_schema_from_uri(schema_uri):
"""
Load the actual schema from the rad resources directly (outside ASDF)
Outside ASDF because this has to occur before the ASDF extensions are
registered.

Parameters
----------
schema_uri : str
The schema_uri found in the RAD manifest

Returns
-------
yaml library dictionary from the schema
"""
filename = f"{schema_uri.split('/')[-1]}.yaml"

if "reference_files" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "reference_files" / filename
elif "/fps/tagged_scalars" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "fps/tagged_scalars" / filename
elif "/fps/" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "fps" / filename
elif "/tvac/tagged_scalars" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "tvac/tagged_scalars" / filename
elif "/tvac/" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "tvac" / filename
elif "tagged_scalars" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "tagged_scalars" / filename
else:
schema_path = BASE_SCHEMA_PATH / filename

return yaml.safe_load(schema_path.read_bytes())


def class_name_from_tag_uri(tag_uri):
Expand All @@ -79,94 +43,83 @@ def class_name_from_tag_uri(tag_uri):
return class_name


def docstring_from_tag(tag):
def docstring_from_tag(tag_def):
"""
Read the docstring (if it exists) from the RAD manifest and generate a docstring
for the dynamically generated class.

Parameters
----------
tag: dict
tag_def: dict
A tag entry from the RAD manifest

Returns
-------
A docstring for the class based on the tag
"""
docstring = f"{tag['description']}\n\n" if "description" in tag else ""
docstring = f"{tag_def['description']}\n\n" if "description" in tag_def else ""

return docstring + f"Class generated from tag '{tag['tag_uri']}'"
return docstring + f"Class generated from tag '{tag_def['tag_uri']}'"


def scalar_factory(tag):
def scalar_factory(pattern, tag_def):
"""
Factory to create a TaggedScalarNode class from a tag

Parameters
----------
tag: dict
pattern: str
A tag pattern/wildcard

tag_def: dict
A tag entry from the RAD manifest

Returns
-------
A dynamically generated TaggedScalarNode subclass
"""
class_name = class_name_from_tag_uri(tag["tag_uri"])
schema = load_schema_from_uri(tag["schema_uri"])
class_name = class_name_from_tag_uri(pattern)

# TaggedScalarNode subclasses are really subclasses of the type of the scalar,
# with the TaggedScalarNode as a mixin. This is because the TaggedScalarNode
# is supposed to be the scalar, but it needs to be serializable under a specific
# ASDF tag.
# SCALAR_TYPE_MAP will need to be updated as new wrappers of scalar types are added
# _SCALAR_TYPE_BY_PATTERN will need to be updated as new wrappers of scalar types are added
# to the RAD manifest.
if "type" in schema:
type_ = schema["type"]
elif "allOf" in schema:
type_ = schema["allOf"][0]["$ref"]
else:
raise RuntimeError(f"Unknown schema type: {schema}")
# assume everything is a string if not otherwise defined
type_ = _SCALAR_TYPE_BY_PATTERN.get(pattern, str)

return type(
class_name,
(SCALAR_TYPE_MAP[type_], TaggedScalarNode),
{"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring_from_tag(tag)},
(type_, TaggedScalarNode),
{
"_pattern": pattern,
"_default_tag": tag_def["tag_uri"],
"__module__": "roman_datamodels.stnode",
"__doc__": docstring_from_tag(tag_def),
},
)


def node_factory(tag):
def node_factory(pattern, tag_def):
"""
Factory to create a TaggedObjectNode or TaggedListNode class from a tag

Parameters
----------
tag: dict
pattern: str
A tag pattern/wildcard

tag_def: dict
A tag entry from the RAD manifest

Returns
-------
A dynamically generated TaggedObjectNode or TaggedListNode subclass
"""
class_name = class_name_from_tag_uri(tag["tag_uri"])
schema = load_schema_from_uri(tag["schema_uri"])

if "type" in schema:
# Determine if the class is a TaggedObjectNode or TaggedListNode based on the
# type defined in the schema:
# - TaggedObjectNode if type is "object"
# - TaggedListNode if type is "array" (array in jsonschema represents Python list)
if schema["type"] == "object":
class_type = TaggedObjectNode
elif schema["type"] == "array":
class_type = TaggedListNode
else:
raise RuntimeError(f"Unknown schema type: {schema['type']}")
# Use of allOf in the schema indicates that the class is a TaggedObjectNode
# which is "extending" another class.
elif "allOf" in schema:
class_type = TaggedObjectNode
else:
raise RuntimeError(f"Unknown schema type for: {tag['schema_uri']}")
class_name = class_name_from_tag_uri(pattern)

class_type = _NODE_TYPE_BY_PATTERN.get(pattern, TaggedObjectNode)

# In special cases one may need to add additional features to a tagged node class.
# This is done by creating a mixin class with the name <ClassName>Mixin in _mixins.py
Expand All @@ -179,17 +132,25 @@ def node_factory(tag):
return type(
class_name,
class_type,
{"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring_from_tag(tag)},
{
"_pattern": pattern,
"_default_tag": tag_def["tag_uri"],
"__module__": "roman_datamodels.stnode",
"__doc__": docstring_from_tag(tag_def),
},
)


def stnode_factory(tag):
def stnode_factory(pattern, tag_def):
"""
Construct a tagged STNode class from a tag

Parameters
----------
tag: dict
pattern: str
A tag pattern/wildcard

tag_def: dict
A tag entry from the RAD manifest

Returns
Expand All @@ -198,7 +159,7 @@ def stnode_factory(tag):
"""
# TaggedScalarNodes are a special case because they are not a subclass of a
# _node class, but rather a subclass of the type of the scalar.
if "tagged_scalar" in tag["schema_uri"]:
return scalar_factory(tag)
if "tagged_scalar" in tag_def["schema_uri"]:
return scalar_factory(pattern, tag_def)
else:
return node_factory(tag)
return node_factory(pattern, tag_def)
4 changes: 2 additions & 2 deletions src/roman_datamodels/stnode/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class DNode(MutableMapping):
Base class describing all "object" (dict-like) data nodes for STNode classes.
"""

_tag = None
_pattern = None
_ctx = None

def __init__(self, node=None, parent=None, name=None):
Expand Down Expand Up @@ -311,7 +311,7 @@ class LNode(UserList):
Base class describing all "array" (list-like) data nodes for STNode classes.
"""

_tag = None
_pattern = None

def __init__(self, node=None):
if node is None:
Expand Down
7 changes: 4 additions & 3 deletions src/roman_datamodels/stnode/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
whenever they are generated.
"""

OBJECT_NODE_CLASSES_BY_TAG = {}
LIST_NODE_CLASSES_BY_TAG = {}
SCALAR_NODE_CLASSES_BY_TAG = {}
OBJECT_NODE_CLASSES_BY_PATTERN = {}
LIST_NODE_CLASSES_BY_PATTERN = {}
SCALAR_NODE_CLASSES_BY_PATTERN = {}
SCALAR_NODE_CLASSES_BY_KEY = {}
NODE_CONVERTERS = {}
NODE_CLASSES_BY_TAG = {}
Loading
Loading