diff --git a/CHANGES.rst b/CHANGES.rst index 88d87dfa..3d2b36c2 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,8 @@ - Allow assignment to or creation of node attributes using dot notation of object instances with validation. [#284] +- Bugfix for ``model.meta.filename`` not matching the filename of the file on disk. [#295] + 0.18.0 (2023-11-06) =================== diff --git a/src/roman_datamodels/datamodels/_core.py b/src/roman_datamodels/datamodels/_core.py index 53e081e8..0cd27761 100644 --- a/src/roman_datamodels/datamodels/_core.py +++ b/src/roman_datamodels/datamodels/_core.py @@ -11,10 +11,9 @@ import copy import datetime import functools -import os -import os.path import sys -from pathlib import PurePath +from contextlib import contextmanager +from pathlib import Path, PurePath import asdf import numpy as np @@ -48,6 +47,26 @@ def wrapper(self, *args, **kwargs): return wrapper +@contextmanager +def _temporary_update_filename(datamodel, filename): + """ + Context manager to temporarily update the filename of a datamodel so that it + can be saved with that new file name without changing the current model's filename + """ + from roman_datamodels.stnode import Filename + + if "meta" in datamodel._instance and "filename" in datamodel._instance.meta: + old_filename = datamodel._instance.meta.filename + datamodel._instance.meta.filename = Filename(filename) + + yield + datamodel._instance.meta.filename = old_filename + return + + yield + return + + class DataModel(abc.ABC): """Base class for all top level datamodels""" @@ -181,17 +200,9 @@ def clone(target, source, deepcopy=False, memo=None): target._ctx = target def save(self, path, dir_path=None, *args, **kwargs): - if callable(path): - path_head, path_tail = os.path.split(path(self.meta.filename)) - else: - path_head, path_tail = os.path.split(path) - base, ext = os.path.splitext(path_tail) - if isinstance(ext, bytes): - ext = ext.decode(sys.getfilesystemencoding()) - - if dir_path: - path_head = dir_path - output_path = os.path.join(path_head, path_tail) + path = Path(path(self.meta.filename) if callable(path) else path) + output_path = Path(dir_path) / path.name if dir_path else path + ext = path.suffix.decode(sys.getfilesystemencoding()) if isinstance(path.suffix, bytes) else path.suffix # TODO: Support gzip-compressed fits if ext == ".asdf": @@ -206,10 +217,10 @@ def open_asdf(self, init=None, **kwargs): return asdf.open(init, **kwargs) if isinstance(init, str) else asdf.AsdfFile(init, **kwargs) def to_asdf(self, init, *args, **kwargs): - with validate.nuke_validation(): - asdffile = self.open_asdf(**kwargs) - asdffile.tree = {"roman": self._instance} - asdffile.write_to(init, *args, **kwargs) + with validate.nuke_validation(), _temporary_update_filename(self, Path(init).name): + asdf_file = self.open_asdf(**kwargs) + asdf_file.tree = {"roman": self._instance} + asdf_file.write_to(init, *args, **kwargs) def get_primary_array_name(self): """ diff --git a/tests/test_models.py b/tests/test_models.py index 766b0231..b27661d7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -804,3 +804,15 @@ def test_datamodel_construct_like_from_like(model): new_mdl = model(mdl) assert new_mdl is mdl assert new_mdl._iscopy == "foo" # Verify that the constructor didn't override stuff + + +def test_datamodel_save_filename(tmp_path): + filename = tmp_path / "fancy_filename.asdf" + ramp = utils.mk_datamodel(datamodels.RampModel, shape=(2, 8, 8)) + assert ramp.meta.filename != filename.name + + ramp.save(filename) + assert ramp.meta.filename != filename.name + + with datamodels.open(filename) as new_ramp: + assert new_ramp.meta.filename == filename.name