diff --git a/CHANGES.rst b/CHANGES.rst index 761600e4..6e30b3fe 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,8 @@ - Update minimum version of numpy to 1.22 as this is the oldest version of numpy which is currently supported. [#258] +- Fix the initialization of empty DataModels and clean up the datamodel core. [#251] + 0.17.1 (2023-08-03) =================== diff --git a/src/roman_datamodels/datamodels/_core.py b/src/roman_datamodels/datamodels/_core.py index 8e8f8ae8..53e081e8 100644 --- a/src/roman_datamodels/datamodels/_core.py +++ b/src/roman_datamodels/datamodels/_core.py @@ -10,6 +10,7 @@ import abc import copy import datetime +import functools import os import os.path import sys @@ -27,6 +28,26 @@ MODEL_REGISTRY = {} +def _set_default_asdf(func): + """ + Decorator which ensures that a DataModel has an asdf file available for use + if required + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if self._asdf is None: + try: + with validate.nuke_validation(): + self._asdf = asdf.AsdfFile({"roman": self._instance}) + except ValidationError as err: + raise ValueError(f"DataModel needs to have all its data flushed out before calling {func.__name__}") from err + + return func(self, *args, **kwargs) + + return wrapper + + class DataModel(abc.ABC): """Base class for all top level datamodels""" @@ -76,56 +97,57 @@ def __init__(self, init=None, **kwargs): self._shape = None self._instance = None self._asdf = None + + if isinstance(init, stnode.TaggedObjectNode): + if not isinstance(self, MODEL_REGISTRY.get(init.__class__)): + expected = {mdl: node for node, mdl in MODEL_REGISTRY.items()}[self.__class__].__name__ + raise ValidationError( + f"TaggedObjectNode: {init.__class__.__name__} is not of the type expected. Expected {expected}" + ) + with validate.nuke_validation(): + self._instance = init + self._asdf = asdf.AsdfFile({"roman": init}) + + return + if init is None: - asdffile = self.open_asdf(init=None, **kwargs) + self._instance = self._node_type() + elif isinstance(init, (str, bytes, PurePath)): if isinstance(init, PurePath): init = str(init) if isinstance(init, bytes): init = init.decode(sys.getfilesystemencoding()) - asdffile = self.open_asdf(init, **kwargs) - if not self.check_type(asdffile): + + self._asdf = self.open_asdf(init, **kwargs) + if not self.check_type(self._asdf): raise ValueError(f"ASDF file is not of the type expected. Expected {self.__class__.__name__}") - self._instance = asdffile.tree["roman"] + + self._instance = self._asdf.tree["roman"] elif isinstance(init, asdf.AsdfFile): - asdffile = init - self._asdf = asdffile - self._instance = asdffile.tree["roman"] - elif isinstance(init, stnode.TaggedObjectNode): - if not isinstance(self, MODEL_REGISTRY.get(init.__class__)): - expected = {mdl: node for node, mdl in MODEL_REGISTRY.items()}[self.__class__].__name__ - raise ValidationError( - f"TaggedObjectNode: {init.__class__.__name__} is not of the type expected. Expected {expected}" - ) - with validate.nuke_validation(): - self._instance = init - asdffile = asdf.AsdfFile() - asdffile.tree = {"roman": init} + self._asdf = init + + self._instance = self._asdf.tree["roman"] else: raise OSError("Argument does not appear to be an ASDF file or TaggedObjectNode.") - self._asdf = asdffile - def check_type(self, asdffile_instance): + def check_type(self, asdf_file): """ Subclass is expected to check for proper type of node """ - if "roman" not in asdffile_instance.tree: + if "roman" not in asdf_file.tree: raise ValueError('ASDF file does not have expected "roman" attribute') - topnode = asdffile_instance.tree["roman"] - if MODEL_REGISTRY[topnode.__class__] != self.__class__: - return False - return True + + return MODEL_REGISTRY[asdf_file.tree["roman"].__class__] == self.__class__ @property def schema_uri(self): # Determine the schema corresponding to this model's tag - schema_uri = next(t for t in stnode.NODE_EXTENSIONS[0].tags if t.tag_uri == self._instance._tag).schema_uris[0] - return schema_uri + return next(t for t in stnode.NODE_EXTENSIONS[0].tags if t.tag_uri == self._instance._tag).schema_uris[0] def close(self): - if not self._iscopy: - if self._asdf is not None: - self._asdf.close() + if not (self._iscopy or self._asdf is None): + self._asdf.close() def __enter__(self): return self @@ -147,15 +169,13 @@ def copy(self, deepcopy=True, memo=None): @staticmethod def clone(target, source, deepcopy=False, memo=None): if deepcopy: - instance = copy.deepcopy(source._instance, memo=memo) target._asdf = source._asdf.copy() - target._instance = instance - target._iscopy = True + target._instance = copy.deepcopy(source._instance, memo=memo) else: target._asdf = source._asdf target._instance = source._instance - target._iscopy = True + target._iscopy = True target._files_to_close = [] target._shape = source._shape target._ctx = target @@ -183,11 +203,7 @@ def save(self, path, dir_path=None, *args, **kwargs): def open_asdf(self, init=None, **kwargs): with validate.nuke_validation(): - if isinstance(init, str): - asdffile = asdf.open(init, **kwargs) - else: - asdffile = asdf.AsdfFile(init, **kwargs) - return asdffile + return asdf.open(init, **kwargs) if isinstance(init, str) else asdf.AsdfFile(init, **kwargs) def to_asdf(self, init, *args, **kwargs): with validate.nuke_validation(): @@ -202,11 +218,7 @@ def get_primary_array_name(self): This is intended to be overridden in the subclasses if the primary array's name is not "data". """ - if hasattr(self, "data"): - primary_array_name = "data" - else: - primary_array_name = "" - return primary_array_name + return "data" if hasattr(self, "data") else "" @property def override_handle(self): @@ -214,7 +226,7 @@ def override_handle(self): would normally be used. """ # Arbitrary choice to look something like crds:// - return "override://" + self.__class__.__name__ + return f"override://{self.__class__.__name__}" @property def shape(self): @@ -266,10 +278,9 @@ def convert_val(val): return str(val) return val - if include_arrays: - return {"roman." + key: convert_val(val) for (key, val) in self.items()} - else: - return {"roman." + key: convert_val(val) for (key, val) in self.items() if not isinstance(val, np.ndarray)} + return { + f"roman.{key}": convert_val(val) for (key, val) in self.items() if include_arrays or not isinstance(val, np.ndarray) + } def items(self): """ @@ -305,12 +316,11 @@ def get_crds_parameters(self): ------- dict """ - crds_header = { + return { key: val for key, val in self.to_flat_dict(include_arrays=False).items() if isinstance(val, (str, int, float, complex, bool)) } - return crds_header def validate(self): """ @@ -318,11 +328,14 @@ def validate(self): """ validate.value_change(self._instance, pass_invalid_values=False, strict_validation=True) + @_set_default_asdf def info(self, *args, **kwargs): return self._asdf.info(*args, **kwargs) + @_set_default_asdf def search(self, *args, **kwargs): return self._asdf.search(*args, **kwargs) + @_set_default_asdf def schema_info(self, *args, **kwargs): return self._asdf.schema_info(*args, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 16a3dda1..430882ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import os + import asdf import pytest import yaml @@ -8,3 +10,23 @@ @pytest.fixture(scope="session") def manifest(): return MANIFEST + + +@pytest.fixture(scope="function") +def nuke_env_var(request): + from roman_datamodels import validate + + assert os.getenv(validate.ROMAN_VALIDATE) == "true" + os.environ[validate.ROMAN_VALIDATE] = request.param + yield request.param, request.param.lower() in ["true", "yes", "1"] + os.environ[validate.ROMAN_VALIDATE] = "true" + + +@pytest.fixture(scope="function") +def nuke_env_strict_var(request): + from roman_datamodels import validate + + assert os.getenv(validate.ROMAN_STRICT_VALIDATION) == "true" + os.environ[validate.ROMAN_STRICT_VALIDATION] = request.param + yield request.param + os.environ[validate.ROMAN_STRICT_VALIDATION] = "true" diff --git a/tests/test_models.py b/tests/test_models.py index be9c922b..59b05f98 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -11,7 +11,7 @@ from roman_datamodels import datamodels from roman_datamodels import maker_utils as utils -from roman_datamodels import stnode +from roman_datamodels import stnode, validate from roman_datamodels.testing import assert_node_equal from .conftest import MANIFEST @@ -63,6 +63,67 @@ def test_model_schemas(node, model): asdf.schema.load_schema(instance.schema_uri) +@pytest.mark.parametrize("node, model", datamodels.MODEL_REGISTRY.items()) +@pytest.mark.parametrize("method", ["info", "search", "schema_info"]) +@pytest.mark.parametrize("nuke_env_var", ["true", "false"], indirect=True) +def test_empty_model_asdf_operations(node, model, method, nuke_env_var): + """ + Test the decorator for asdf operations on models when the model is left truly empty. + """ + mdl = model() + assert isinstance(mdl._instance, node) + + # Check that the model does not have the asdf attribute set. + assert mdl._asdf is None + + # Depending on the state for nuke_validation we either expect an error or a + # warning to be raised. + # - error: when nuke_env_var == true + # - warning: when nuke_env_var == false + msg = f"DataModel needs to have all its data flushed out before calling {method}" + context = pytest.raises(ValueError, match=msg) if nuke_env_var[1] else pytest.warns(validate.ValidationWarning) + + # Execute the method we wish to test, and catch the expected error/warning. + with context: + getattr(mdl, method)() + + if nuke_env_var[1]: + # If an error is raised (nuke_env_var == true), then the asdf attribute should + # fail to be set. + assert mdl._asdf is None + else: + # In a warning is raised (nuke_env_var == false), then the asdf attribute should + # be set to something. + assert mdl._asdf is not None + + +@pytest.mark.parametrize("node, model", datamodels.MODEL_REGISTRY.items()) +@pytest.mark.parametrize("method", ["info", "search", "schema_info"]) +def test_model_asdf_operations(node, model, method): + """ + Test the decorator for asdf operations on models when an empty initial model + which is then filled. + """ + # Create an empty model + mdl = model() + assert isinstance(mdl._instance, node) + + # Check there model prior to filling raises an error. + with pytest.raises(ValueError): + getattr(mdl, method)() + + # Fill the model with data, but no asdf file is present + mdl._instance = utils.mk_node(node) + assert mdl._asdf is None + + # Run the method we wish to test (it should fail with warning or error + # if something is broken) + getattr(mdl, method)() + + # Show that mdl._asdf is now set + assert mdl._asdf is not None + + # Testing core schema def test_core_schema(tmp_path): # Set temporary asdf file diff --git a/tests/test_stnode.py b/tests/test_stnode.py index 0818002c..16f8f68f 100644 --- a/tests/test_stnode.py +++ b/tests/test_stnode.py @@ -181,21 +181,17 @@ def test_set_pattern_properties(): mdl.phot_table.F062.pixelareasr = None -@pytest.fixture(scope="function", params=["true", "yes", "1", "True", "Yes", "TrUe", "YeS", "foo", "Bar", "BaZ"]) -def env_var(request): - assert os.getenv(validate.ROMAN_VALIDATE) == "true" - os.environ[validate.ROMAN_VALIDATE] = request.param - yield request.param, request.param.lower() in ["true", "yes", "1"] - os.environ[validate.ROMAN_VALIDATE] = "true" +VALIDATION_CASES = ("true", "yes", "1", "True", "Yes", "TrUe", "YeS", "foo", "Bar", "BaZ") -def test_will_validate(env_var): +@pytest.mark.parametrize("nuke_env_var", VALIDATION_CASES, indirect=True) +def test_will_validate(nuke_env_var): # Test the fixture passed the value of the environment variable - value = env_var[0] + value = nuke_env_var[0] assert os.getenv(validate.ROMAN_VALIDATE) == value # Test the validate property - truth = env_var[1] + truth = nuke_env_var[1] context = nullcontext() if truth else pytest.warns(validate.ValidationWarning) with context: @@ -217,8 +213,9 @@ def test_will_validate(env_var): assert validate.will_validate() is True -def test_nuke_validation(env_var, tmp_path): - context = pytest.raises(asdf.ValidationError) if env_var[1] else pytest.warns(validate.ValidationWarning) +@pytest.mark.parametrize("nuke_env_var", VALIDATION_CASES, indirect=True) +def test_nuke_validation(nuke_env_var, tmp_path): + context = pytest.raises(asdf.ValidationError) if nuke_env_var[1] else pytest.warns(validate.ValidationWarning) # Create a broken DNode object mdl = maker_utils.mk_wfi_img_photom() @@ -232,7 +229,7 @@ def test_nuke_validation(env_var, tmp_path): mdl.phot_table = "THIS IS NOT VALID" # Break model without outside validation - with nullcontext() if env_var[1] else pytest.warns(validate.ValidationWarning): + with nullcontext() if nuke_env_var[1] else pytest.warns(validate.ValidationWarning): mdl = datamodels.WfiImgPhotomRefModel(maker_utils.mk_wfi_img_photom()) mdl._instance["phot_table"] = "THIS IS NOT VALID" @@ -240,20 +237,20 @@ def test_nuke_validation(env_var, tmp_path): broken_save = tmp_path / "broken_save.asdf" with context: mdl.save(broken_save) - assert os.path.isfile(broken_save) is not env_var[1] + assert os.path.isfile(broken_save) is not nuke_env_var[1] broken_to_asdf = tmp_path / "broken_to_asdf.asdf" with context: mdl.to_asdf(broken_to_asdf) - assert os.path.isfile(broken_to_asdf) is not env_var[1] + assert os.path.isfile(broken_to_asdf) is not nuke_env_var[1] # Create a broken file for reading if needed - if env_var[1]: + if nuke_env_var[1]: os.environ[validate.ROMAN_VALIDATE] = "false" with pytest.warns(validate.ValidationWarning): mdl.save(broken_save) mdl.to_asdf(broken_to_asdf) - os.environ[validate.ROMAN_VALIDATE] = env_var[0] + os.environ[validate.ROMAN_VALIDATE] = nuke_env_var[0] # Read broken files with datamodel object with context: @@ -270,32 +267,25 @@ def test_nuke_validation(env_var, tmp_path): pass -@pytest.fixture(scope="function", params=["true", "yes", "1", "True", "Yes", "TrUe", "YeS", "foo", "Bar", "BaZ"]) -def env_strict_var(request): - assert os.getenv(validate.ROMAN_STRICT_VALIDATION) == "true" - os.environ[validate.ROMAN_STRICT_VALIDATION] = request.param - yield request.param - os.environ[validate.ROMAN_STRICT_VALIDATION] = "true" - - -def test_will_strict_validate(env_strict_var): +@pytest.mark.parametrize("nuke_env_strict_var", VALIDATION_CASES, indirect=True) +def test_will_strict_validate(nuke_env_strict_var): # Test the fixture passed the value of the environment variable - assert os.getenv(validate.ROMAN_STRICT_VALIDATION) == env_strict_var + assert os.getenv(validate.ROMAN_STRICT_VALIDATION) == nuke_env_strict_var # Test the validate property - truth = env_strict_var.lower() in ["true", "yes", "1"] + truth = nuke_env_strict_var.lower() in ["true", "yes", "1"] context = nullcontext() if truth else pytest.warns(validate.ValidationWarning) with context: assert validate.will_strict_validate() is truth # Try all uppercase - os.environ[validate.ROMAN_STRICT_VALIDATION] = env_strict_var.upper() + os.environ[validate.ROMAN_STRICT_VALIDATION] = nuke_env_strict_var.upper() with context: assert validate.will_strict_validate() is truth # Try all lowercase - os.environ[validate.ROMAN_STRICT_VALIDATION] = env_strict_var.lower() + os.environ[validate.ROMAN_STRICT_VALIDATION] = nuke_env_strict_var.lower() with context: assert validate.will_strict_validate() is truth