Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up some parts of datamodels._core #251

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
- Update minimum version of numpy to 1.22 as this is the oldest version of numpy
which is currently supported. [#258]

- Fix the initialization of empty DataModels and clean up the datamodel core. [#251]

0.17.1 (2023-08-03)
===================

Expand Down
113 changes: 63 additions & 50 deletions src/roman_datamodels/datamodels/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import abc
import copy
import datetime
import functools
import os
import os.path
import sys
Expand All @@ -27,6 +28,26 @@
MODEL_REGISTRY = {}


def _set_default_asdf(func):
"""
Decorator which ensures that a DataModel has an asdf file available for use
if required
"""

@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if self._asdf is None:
try:
with validate.nuke_validation():
self._asdf = asdf.AsdfFile({"roman": self._instance})
except ValidationError as err:
raise ValueError(f"DataModel needs to have all its data flushed out before calling {func.__name__}") from err

return func(self, *args, **kwargs)

return wrapper


class DataModel(abc.ABC):
"""Base class for all top level datamodels"""

Expand Down Expand Up @@ -76,56 +97,57 @@
self._shape = None
self._instance = None
self._asdf = None

if isinstance(init, stnode.TaggedObjectNode):
if not isinstance(self, MODEL_REGISTRY.get(init.__class__)):
expected = {mdl: node for node, mdl in MODEL_REGISTRY.items()}[self.__class__].__name__
raise ValidationError(
f"TaggedObjectNode: {init.__class__.__name__} is not of the type expected. Expected {expected}"
)
with validate.nuke_validation():
self._instance = init
self._asdf = asdf.AsdfFile({"roman": init})

return

if init is None:
asdffile = self.open_asdf(init=None, **kwargs)
self._instance = self._node_type()

elif isinstance(init, (str, bytes, PurePath)):
if isinstance(init, PurePath):
init = str(init)
if isinstance(init, bytes):
init = init.decode(sys.getfilesystemencoding())
asdffile = self.open_asdf(init, **kwargs)
if not self.check_type(asdffile):

self._asdf = self.open_asdf(init, **kwargs)
if not self.check_type(self._asdf):
raise ValueError(f"ASDF file is not of the type expected. Expected {self.__class__.__name__}")
self._instance = asdffile.tree["roman"]

self._instance = self._asdf.tree["roman"]
elif isinstance(init, asdf.AsdfFile):
asdffile = init
self._asdf = asdffile
self._instance = asdffile.tree["roman"]
elif isinstance(init, stnode.TaggedObjectNode):
if not isinstance(self, MODEL_REGISTRY.get(init.__class__)):
expected = {mdl: node for node, mdl in MODEL_REGISTRY.items()}[self.__class__].__name__
raise ValidationError(
f"TaggedObjectNode: {init.__class__.__name__} is not of the type expected. Expected {expected}"
)
with validate.nuke_validation():
self._instance = init
asdffile = asdf.AsdfFile()
asdffile.tree = {"roman": init}
self._asdf = init

self._instance = self._asdf.tree["roman"]
else:
raise OSError("Argument does not appear to be an ASDF file or TaggedObjectNode.")
self._asdf = asdffile

def check_type(self, asdffile_instance):
def check_type(self, asdf_file):
"""
Subclass is expected to check for proper type of node
"""
if "roman" not in asdffile_instance.tree:
if "roman" not in asdf_file.tree:
raise ValueError('ASDF file does not have expected "roman" attribute')
topnode = asdffile_instance.tree["roman"]
if MODEL_REGISTRY[topnode.__class__] != self.__class__:
return False
return True

return MODEL_REGISTRY[asdf_file.tree["roman"].__class__] == self.__class__

@property
def schema_uri(self):
# Determine the schema corresponding to this model's tag
schema_uri = next(t for t in stnode.NODE_EXTENSIONS[0].tags if t.tag_uri == self._instance._tag).schema_uris[0]
return schema_uri
return next(t for t in stnode.NODE_EXTENSIONS[0].tags if t.tag_uri == self._instance._tag).schema_uris[0]

def close(self):
if not self._iscopy:
if self._asdf is not None:
self._asdf.close()
if not (self._iscopy or self._asdf is None):
self._asdf.close()

def __enter__(self):
return self
Expand All @@ -147,15 +169,13 @@
@staticmethod
def clone(target, source, deepcopy=False, memo=None):
if deepcopy:
instance = copy.deepcopy(source._instance, memo=memo)
target._asdf = source._asdf.copy()
target._instance = instance
target._iscopy = True
target._instance = copy.deepcopy(source._instance, memo=memo)
else:
target._asdf = source._asdf
target._instance = source._instance
target._iscopy = True

target._iscopy = True
target._files_to_close = []
target._shape = source._shape
target._ctx = target
Expand Down Expand Up @@ -183,11 +203,7 @@

def open_asdf(self, init=None, **kwargs):
with validate.nuke_validation():
if isinstance(init, str):
asdffile = asdf.open(init, **kwargs)
else:
asdffile = asdf.AsdfFile(init, **kwargs)
return asdffile
return asdf.open(init, **kwargs) if isinstance(init, str) else asdf.AsdfFile(init, **kwargs)

def to_asdf(self, init, *args, **kwargs):
with validate.nuke_validation():
Expand All @@ -202,19 +218,15 @@
This is intended to be overridden in the subclasses if the
primary array's name is not "data".
"""
if hasattr(self, "data"):
primary_array_name = "data"
else:
primary_array_name = ""
return primary_array_name
return "data" if hasattr(self, "data") else ""

@property
def override_handle(self):
"""override_handle identifies in-memory models where a filepath
would normally be used.
"""
# Arbitrary choice to look something like crds://
return "override://" + self.__class__.__name__
return f"override://{self.__class__.__name__}"

Check warning on line 229 in src/roman_datamodels/datamodels/_core.py

View check run for this annotation

Codecov / codecov/patch

src/roman_datamodels/datamodels/_core.py#L229

Added line #L229 was not covered by tests

@property
def shape(self):
Expand Down Expand Up @@ -266,10 +278,9 @@
return str(val)
return val

if include_arrays:
return {"roman." + key: convert_val(val) for (key, val) in self.items()}
else:
return {"roman." + key: convert_val(val) for (key, val) in self.items() if not isinstance(val, np.ndarray)}
return {
f"roman.{key}": convert_val(val) for (key, val) in self.items() if include_arrays or not isinstance(val, np.ndarray)
}

def items(self):
"""
Expand Down Expand Up @@ -305,24 +316,26 @@
-------
dict
"""
crds_header = {
return {
key: val
for key, val in self.to_flat_dict(include_arrays=False).items()
if isinstance(val, (str, int, float, complex, bool))
}
return crds_header

def validate(self):
"""
Re-validate the model instance against the tags
"""
validate.value_change(self._instance, pass_invalid_values=False, strict_validation=True)

@_set_default_asdf
def info(self, *args, **kwargs):
return self._asdf.info(*args, **kwargs)

@_set_default_asdf
def search(self, *args, **kwargs):
return self._asdf.search(*args, **kwargs)

@_set_default_asdf
def schema_info(self, *args, **kwargs):
return self._asdf.schema_info(*args, **kwargs)
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import asdf
import pytest
import yaml
Expand All @@ -8,3 +10,23 @@
@pytest.fixture(scope="session")
def manifest():
return MANIFEST


@pytest.fixture(scope="function")
def nuke_env_var(request):
from roman_datamodels import validate

assert os.getenv(validate.ROMAN_VALIDATE) == "true"
os.environ[validate.ROMAN_VALIDATE] = request.param
yield request.param, request.param.lower() in ["true", "yes", "1"]
os.environ[validate.ROMAN_VALIDATE] = "true"


@pytest.fixture(scope="function")
def nuke_env_strict_var(request):
from roman_datamodels import validate

assert os.getenv(validate.ROMAN_STRICT_VALIDATION) == "true"
os.environ[validate.ROMAN_STRICT_VALIDATION] = request.param
yield request.param
os.environ[validate.ROMAN_STRICT_VALIDATION] = "true"
63 changes: 62 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from roman_datamodels import datamodels
from roman_datamodels import maker_utils as utils
from roman_datamodels import stnode
from roman_datamodels import stnode, validate
from roman_datamodels.testing import assert_node_equal

from .conftest import MANIFEST
Expand Down Expand Up @@ -63,6 +63,67 @@ def test_model_schemas(node, model):
asdf.schema.load_schema(instance.schema_uri)


@pytest.mark.parametrize("node, model", datamodels.MODEL_REGISTRY.items())
@pytest.mark.parametrize("method", ["info", "search", "schema_info"])
@pytest.mark.parametrize("nuke_env_var", ["true", "false"], indirect=True)
def test_empty_model_asdf_operations(node, model, method, nuke_env_var):
"""
Test the decorator for asdf operations on models when the model is left truly empty.
"""
mdl = model()
assert isinstance(mdl._instance, node)

# Check that the model does not have the asdf attribute set.
assert mdl._asdf is None

# Depending on the state for nuke_validation we either expect an error or a
# warning to be raised.
# - error: when nuke_env_var == true
# - warning: when nuke_env_var == false
msg = f"DataModel needs to have all its data flushed out before calling {method}"
context = pytest.raises(ValueError, match=msg) if nuke_env_var[1] else pytest.warns(validate.ValidationWarning)

# Execute the method we wish to test, and catch the expected error/warning.
with context:
getattr(mdl, method)()

if nuke_env_var[1]:
# If an error is raised (nuke_env_var == true), then the asdf attribute should
# fail to be set.
assert mdl._asdf is None
else:
# In a warning is raised (nuke_env_var == false), then the asdf attribute should
# be set to something.
assert mdl._asdf is not None


@pytest.mark.parametrize("node, model", datamodels.MODEL_REGISTRY.items())
@pytest.mark.parametrize("method", ["info", "search", "schema_info"])
def test_model_asdf_operations(node, model, method):
"""
Test the decorator for asdf operations on models when an empty initial model
which is then filled.
"""
# Create an empty model
mdl = model()
assert isinstance(mdl._instance, node)

# Check there model prior to filling raises an error.
with pytest.raises(ValueError):
getattr(mdl, method)()

# Fill the model with data, but no asdf file is present
mdl._instance = utils.mk_node(node)
assert mdl._asdf is None

# Run the method we wish to test (it should fail with warning or error
# if something is broken)
getattr(mdl, method)()

# Show that mdl._asdf is now set
assert mdl._asdf is not None


# Testing core schema
def test_core_schema(tmp_path):
# Set temporary asdf file
Expand Down
Loading