Skip to content

Commit

Permalink
Merge pull request #8 from mmcdermott/file_name_docs
Browse files Browse the repository at this point in the history
Added docstrings and tests for `file_name.py`
  • Loading branch information
mmcdermott authored Jun 12, 2024
2 parents 9810bac + fe8971d commit b2883a8
Showing 1 changed file with 90 additions and 8 deletions.
98 changes: 90 additions & 8 deletions src/MEDS_tabular_automl/file_name.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,110 @@
"""Help functions for getting file names and paths for MEDS tabular automl tasks."""
"""Helper functions for getting file names and paths for MEDS tabular automl tasks."""
from pathlib import Path

from omegaconf import DictConfig

def list_subdir_files(dir: [Path | str], fmt: str):
return sorted(list(Path(dir).glob(f"**/*.{fmt}")))

def list_subdir_files(root: Path | str, ext: str) -> list[Path]:
"""List files in subdirectories of a directory with a given extension.
def get_task_specific_path(cfg, split, shard_num, window_size, agg):
return Path(cfg.input_dir) / split / f"{shard_num}" / f"{window_size}" / f"{agg}.npz"
Args:
root: Path to the directory.
ext: File extension to filter files.
Returns:
An alphabetically sorted list of Path objects to files matching the extension in any level of
subdirectories of the given directory.
def get_model_files(cfg, split: str, shard_num: int):
Examples:
>>> import tempfile
>>> tmpdir = tempfile.TemporaryDirectory()
>>> root = Path(tmpdir.name)
>>> subdir_1 = root / "subdir_1"
>>> subdir_1.mkdir()
>>> subdir_2 = root / "subdir_2"
>>> subdir_2.mkdir()
>>> subdir_1_A = subdir_1 / "A"
>>> subdir_1_A.mkdir()
>>> (root / "1.csv").touch()
>>> (root / "foo.parquet").touch()
>>> (root / "2.csv").touch()
>>> (root / "subdir_1" / "3.csv").touch()
>>> (root / "subdir_2" / "4.csv").touch()
>>> (root / "subdir_1" / "A" / "5.csv").touch()
>>> (root / "subdir_1" / "A" / "15.csv.gz").touch()
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "csv")] # doctest: +NORMALIZE_WHITESPACE
[PosixPath('1.csv'),
PosixPath('2.csv'),
PosixPath('subdir_1/3.csv'),
PosixPath('subdir_1/A/5.csv'),
PosixPath('subdir_2/4.csv')]
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "parquet")]
[PosixPath('foo.parquet')]
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "csv.gz")]
[PosixPath('subdir_1/A/15.csv.gz')]
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "json")]
[]
>>> list_subdir_files(root / "nonexistent", "csv")
[]
>>> tmpdir.cleanup()
"""

return sorted(list(Path(root).glob(f"**/*.{ext}")))


def get_model_files(cfg: DictConfig, split: str, shard: str) -> list[Path]:
"""Get the tabularized npz files for a given split and shard number.
TODO: Rename function to get_tabularized_input_files or something.
Args:
cfg: `OmegaConf.DictConfig` object with the configuration. It must have the following keys:
- input_dir: Path to the directory with the tabularized npz files.
- tabularization: Tabularization configuration, as a nested `DictConfig` object with keys:
- window_sizes: List of window sizes.
- aggs: List of aggregation functions.
split: Split name to reference the files stored on disk.
shard: The shard within the split to reference the files stored on disk.
Returns:
An alphabetically sorted list of Path objects to the tabularized npz files for the given split and
shard. These files will take the form ``{cfg.input_dir}/{split}/{shard}/{window_size}/{agg}.npz``. For
static aggregations, the window size will be "none" as these features are not time-varying.
Examples:
>>> cfg = DictConfig({
... "input_dir": "data",
... "tabularization": {
... "window_sizes": ["1d", "7d"],
... "aggs": ["code/count", "value/sum", "static/present"],
... }
... })
>>> get_model_files(cfg, "train", "0") # doctest: +NORMALIZE_WHITESPACE
[PosixPath('data/train/0/1d/code/count.npz'),
PosixPath('data/train/0/1d/value/sum.npz'),
PosixPath('data/train/0/7d/code/count.npz'),
PosixPath('data/train/0/7d/value/sum.npz'),
PosixPath('data/train/0/none/static/present.npz')]
>>> get_model_files(cfg, "test/IID", "3/0") # doctest: +NORMALIZE_WHITESPACE
[PosixPath('data/test/IID/3/0/1d/code/count.npz'),
PosixPath('data/test/IID/3/0/1d/value/sum.npz'),
PosixPath('data/test/IID/3/0/7d/code/count.npz'),
PosixPath('data/test/IID/3/0/7d/value/sum.npz'),
PosixPath('data/test/IID/3/0/none/static/present.npz')]
"""
window_sizes = cfg.tabularization.window_sizes
aggs = cfg.tabularization.aggs
shard_dir = Path(cfg.input_dir) / split / shard
# Given a shard number, returns the model files
model_files = []
for window_size in window_sizes:
for agg in aggs:
if agg.startswith("static"):
continue
else:
model_files.append(get_task_specific_path(cfg, split, shard_num, window_size, agg))
model_files.append(shard_dir / window_size / f"{agg}.npz")
for agg in aggs:
if agg.startswith("static"):
window_size = "none"
model_files.append(get_task_specific_path(cfg, split, shard_num, window_size, agg))
model_files.append(shard_dir / window_size / f"{agg}.npz")
return sorted(model_files)

0 comments on commit b2883a8

Please sign in to comment.