Skip to content

Commit

Permalink
docstrings up to load_meds_data
Browse files Browse the repository at this point in the history
  • Loading branch information
aleksiakolo committed Jun 12, 2024
1 parent 5854556 commit 846a604
Showing 1 changed file with 34 additions and 25 deletions.
59 changes: 34 additions & 25 deletions src/MEDS_tabular_automl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,16 @@ def get_ts_feature_cols(shard_df: pl.LazyFrame) -> list[str]:
def get_prediction_ts_cols(
aggregations: list[str], ts_feature_cols: pl.LazyFrame, window_sizes: list[str] | None = None
) -> list[str]:
"""Generates a list of feature column names that will be used for downstream task."""
"""Generates a list of feature column names for prediction tasks based on aggregations and window sizes.
Args:
aggregations: The list of aggregation methods to apply.
ts_feature_cols: The list of existing time-series feature columns.
window_sizes: The optional list of window sizes to consider.
Returns:
A list of feature column names formatted with aggregation and window size.
"""
agg_feature_columns = []
for code in ts_feature_cols:
ts_aggregations = [f"{code}/{agg}" for agg in aggregations]
Expand All @@ -354,47 +363,47 @@ def get_prediction_ts_cols(


def get_flat_rep_feature_cols(cfg: DictConfig, shard_df: pl.LazyFrame) -> list[str]:
    """Combines static and time-series feature columns from a shard.

    Args:
        cfg: The configuration object specifying aggregation settings. It is accepted
            for interface compatibility but is not consulted by this function.
        shard_df: The LazyFrame shard in MEDS format to process.

    Returns:
        A combined list of feature columns: the static feature columns followed by
        the time-series feature columns.
    """
    static_feature_columns = get_static_feature_cols(shard_df)
    # `get_ts_feature_cols` is declared with a single `shard_df` parameter, so the
    # previous call `get_ts_feature_cols(cfg.aggs, shard_df)` would raise TypeError.
    # NOTE(review): if the aggregations in `cfg.aggs` are meant to be expanded into
    # the column names here, that presumably belongs to `get_prediction_ts_cols`;
    # confirm the intended behavior with the callers.
    ts_feature_columns = get_ts_feature_cols(shard_df)
    return static_feature_columns + ts_feature_columns


def load_meds_data(MEDS_cohort_dir: str, load_data: bool = True) -> Mapping[str, pl.DataFrame]:
"""Loads the MEDS dataset from disk.
def load_meds_data(MEDS_cohort_dir: str, load_data: bool = True) -> Mapping[str, pl.LazyFrame]:
"""Loads the MEDS dataset from disk, structured by data splits.
Args:
MEDS_cohort_dir: The directory containing the MEDS datasets split by subfolders.
We expect `train` to be a split so `MEDS_cohort_dir/train` should exist.
load_data: If True, returns LazyFrames for each data split, otherwise returns file paths.
Returns:
Mapping[str, pl.DataFrame]: Mapping from split name to a polars DataFrame containing the MEDS dataset.
A dictionary mapping each split name to a list of LazyFrames (one per parquet file found for that split), or to the corresponding file paths when `load_data` is False.
Example:
>>> import tempfile
>>> from pathlib import Path
>>> MEDS_cohort_dir = Path(tempfile.mkdtemp())
>>> for split in ["train", "val", "test"]:
... split_dir = MEDS_cohort_dir / split
... split_dir.mkdir()
... pl.DataFrame({"patient_id": [1, 2, 3]}).write_parquet(split_dir / "data.parquet")
>>> split_to_df = load_meds_data(MEDS_cohort_dir)
>>> assert "train" in split_to_df
>>> assert len(split_to_df) == 3
>>> assert len(split_to_df["train"]) == 1
>>> assert isinstance(split_to_df["train"][0], pl.LazyFrame)
>>> import tempfile
>>> from pathlib import Path
>>> MEDS_cohort_dir = Path(tempfile.mkdtemp())
>>> for split in ["train", "val", "test"]:
... split_dir = MEDS_cohort_dir / split
... split_dir.mkdir()
... pl.DataFrame({"patient_id": [1, 2, 3]}).write_parquet(split_dir / "data.parquet")
>>> split_to_df = load_meds_data(MEDS_cohort_dir)
>>> assert "train" in split_to_df
>>> assert len(split_to_df) == 3
>>> assert len(split_to_df["train"]) == 1
>>> assert isinstance(split_to_df["train"][0], pl.LazyFrame)
"""
MEDS_cohort_dir = Path(MEDS_cohort_dir)
meds_fps = list(MEDS_cohort_dir.glob("*/*.parquet"))
Expand Down

0 comments on commit 846a604

Please sign in to comment.