diff --git a/src/imars3d/backend/dataio/data.py b/src/imars3d/backend/dataio/data.py index 5b92cbdf..787ff5f3 100644 --- a/src/imars3d/backend/dataio/data.py +++ b/src/imars3d/backend/dataio/data.py @@ -104,14 +104,14 @@ class load_data(param.ParameterizedFunction): Notes ----- - There are two main signatures to load the data: + There are three main signatures to load the data: 1. load_data(ct_files=ctfs, ob_files=obfs, dc_files=dcfs) 2. load_data(ct_dir=ctdir, ob_dir=obdir, dc_dir=dcdir) + 3. load_data(ct_dir=ctdir, ob_files=obfs, dc_files=dcfs) - The two signatures are mutually exclusive, and dc_files and dc_dir are optional - in both cases as some experiments do not have dark current measurements. + In all signatures, dc_files and dc_dir are optional. - The fnmatch selectors are applicable in both signature, which help to down-select + The fnmatch selectors are applicable in all signatures, which help to down-select files if needed. Default is set to "*", which selects everything. Also, if ob_fnmatch and dc_fnmatch are set to "None" in the second signature call, the data loader will attempt to read the metadata embedded in the first ct file to find obs @@ -157,9 +157,43 @@ def __call__(self, **params): # use set to simplify call signature checking sigs = set([k.split("_")[-1] for k in params.keys() if "fnmatch" not in k]) ref = {"files", "dir"} - if sigs.intersection(ref) == {"files", "dir"}: - logger.error("Files and dir cannot be used at the same time") - raise ValueError("Mix usage of allowed signature.") + + if ("ct_dir" in params.keys()) and ("ob_files" in params.keys()): + logger.debug("Load ct by directory, ob and dc (if any) by files") + ct_dir = params.get("ct_dir") + if not Path(ct_dir).exists(): + logger.error(f"ct_dir {ct_dir} does not exist.") + raise ValueError("ct_dir does not exist.") + else: + ct_dir = Path(ct_dir) + + # gather the ct_files + ct_fnmatch = params.get("ct_fnmatch", "*") + ct_files = ct_dir.glob(ct_fnmatch) + ct_files = list(map(str, 
ct_files)) + ct_files.sort() + + ob_files = (params.get("ob_files"),) + dc_files = (params.get("dc_files", []),) # it is okay to skip dc + + ob_files = ob_files[0] + dc_files = dc_files[0] + + ct, ob, dc = _load_by_file_list( + ct_files=ct_files, + ob_files=ob_files, + dc_files=dc_files, # it is okay to skip dc + ct_fnmatch=params.get("ct_fnmatch", "*"), # in case None got leaked here + ob_fnmatch=params.get("ob_fnmatch", "*"), + dc_fnmatch=params.get("dc_fnmatch", "*"), + max_workers=self.max_workers, + tqdm_class=params.tqdm_class, + ) + + elif ("ct_files" in params.keys()) and ("ob_dir" in params.keys()): + logger.error("ct_files and ob_dir mixed not allowed!") + raise ValueError("Mix signatures (ct_files, ob_dir) not allowed!") + elif sigs.intersection(ref) == {"files"}: logger.debug("Load by file list") ct, ob, dc = _load_by_file_list( diff --git a/tests/unit/backend/dataio/test_data.py b/tests/unit/backend/dataio/test_data.py index af912989..82cb87b0 100644 --- a/tests/unit/backend/dataio/test_data.py +++ b/tests/unit/backend/dataio/test_data.py @@ -97,6 +97,8 @@ def test_load_data( # error_0: incorrect input argument types with pytest.raises(ValueError): load_data(ct_files=1, ob_files=[], dc_files=[]) + load_data(ct_dir=1, ob_files=[]) + load_data(ct_files=[], ob_dir="/tmp") load_data(ct_files=[], ob_files=[], dc_files=[], ct_fnmatch=1) load_data(ct_files=[], ob_files=[], dc_files=[], ob_fnmatch=1) load_data(ct_files=[], ob_files=[], dc_files=[], dc_fnmatch=1) @@ -106,16 +108,16 @@ def test_load_data( # error_1: out of bounds value with pytest.raises(ValueError): load_data(ct_files=[], ob_files=[], dc_files=[], max_workers=-1) - # error_2: mix usage of function signature 1 and 2 - with pytest.raises(ValueError): - load_data(ct_files=[], ob_files=[], dc_files=[], ct_dir="/tmp", ob_dir="/tmp") # error_3: no valid signature found with pytest.raises(ValueError): load_data(ct_fnmatch=1) - # case_0: load data from file list + # case_0: load ct from directory, ob 
and dc from files + rst = load_data(ct_dir="/tmp", ob_files=["3", "4"], dc_files=["5", "6"]) + np.testing.assert_almost_equal(np.array(rst).flatten(), np.arange(1, 5, dtype=float)) + # case_1: load data from file list rst = load_data(ct_files=["1", "2"], ob_files=["3", "4"], dc_files=["5", "6"]) np.testing.assert_almost_equal(np.array(rst).flatten(), np.arange(1, 5, dtype=float)) - # case_1: load data from given directory + # case_2: load data from given directory rst = load_data(ct_dir="/tmp", ob_dir="/tmp", dc_dir="/tmp") np.testing.assert_almost_equal(np.array(rst).flatten(), np.arange(1, 5, dtype=float))