Skip to content

Commit

Permalink
Merge pull request #13 from mmcdermott/utils
Browse files Browse the repository at this point in the history
Adding docstrings to src/MEDS_tabular_automl/utils.py
  • Loading branch information
mmcdermott authored Jun 13, 2024
2 parents d8e9de6 + 3f4b87c commit 1a3cfe2
Show file tree
Hide file tree
Showing 3 changed files with 405 additions and 176 deletions.
178 changes: 111 additions & 67 deletions src/MEDS_tabular_automl/describe_codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,40 @@
from MEDS_tabular_automl.utils import DF_T, get_feature_names


def convert_to_df(freq_dict: dict[str, int]) -> pl.DataFrame:
    """Builds a two-column Polars DataFrame from a mapping of codes to frequencies.

    Args:
        freq_dict: A dictionary with code features and their respective frequencies.

    Returns:
        A DataFrame with two columns, "code" and "count".
    """
    rows = [[code, count] for code, count in freq_dict.items()]
    return pl.DataFrame(rows, schema=["code", "count"])


def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> list[str]:
"""Generates a list of feature column names from the data within each shard based on specified
configurations.
def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> pl.DataFrame:
"""Generates a DataFrame containing the frequencies of codes and numerical values under different
aggregations by computing frequency counts for certain attributes and organizing the results into specific
categories based on the dataset's features.
Parameters:
- cfg (DictConfig): Configuration dictionary specifying how features should be evaluated and aggregated.
- split_to_shard_df (dict): A dictionary of DataFrames, divided by data split (e.g., 'train', 'test').
Args:
cfg: Configuration dictionary specifying how features should be evaluated and aggregated.
shard_df: A DataFrame containing the data to be analyzed and split (e.g., 'train', 'test').
Returns:
- tuple[list[str], dict]: A tuple containing a list of feature columns and a dictionary of code properties
identified during the evaluation.
A tuple containing a list of feature columns and a dictionary of code properties identified
during the evaluation.
This function evaluates the properties of codes within training data and applies configured
aggregations to generate a comprehensive list of feature columns for modeling purposes.
Examples:
# >>> import polars as pl
# >>> data = {'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'],
# ... 'timestamp': [None, '2021-01-01', None, None, '2021-01-03', '2021-01-04', None],
# ... 'numerical_value': [1, None, 2, 2, None, None, 3]}
# >>> df = pl.DataFrame(data).lazy()
# >>> aggs = ['value/sum', 'code/count']
# >>> compute_feature_frequencies(aggs, df)
# ['A/code', 'A/value', 'C/code', 'C/value']
# >>> import polars as pl
# >>> data = {'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'],
# ... 'timestamp': [None, '2021-01-01', None, None, '2021-01-03', '2021-01-04', None],
# ... 'numerical_value': [1, None, 2, 2, None, None, 3]}
# >>> df = pl.DataFrame(data).lazy()
# >>> aggs = ['value/sum', 'code/count']
# >>> compute_feature_frequencies(aggs, df)
# ['A/code', 'A/value', 'C/code', 'C/value']
"""
static_df = shard_df.filter(
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_null()
Expand Down Expand Up @@ -64,48 +71,73 @@ def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> list[str]:
return convert_to_df(combined_freqs)


def convert_to_freq_dict(df: pl.LazyFrame) -> dict[str, int]:
    """Converts a two-column ("code", "count") LazyFrame to a dictionary of frequencies.

    Args:
        df: The LazyFrame to be converted; it must have exactly the columns "code" and "count".

    Returns:
        A dictionary mapping each code to its frequency count.

    Raises:
        ValueError: If the DataFrame does not have exactly the columns "code" and "count".
    """
    # Guard the schema explicitly so arbitrary column pairs are never silently mis-mapped.
    if df.columns != ["code", "count"]:
        raise ValueError(f"DataFrame must have columns 'code' and 'count', but has columns {df.columns}!")
    # Each row is a (code, count) pair, so dict() builds the mapping directly.
    return dict(df.collect().iter_rows())


def get_feature_columns(fp: Path) -> list[str]:
    """Retrieves feature column names from a parquet file.

    Args:
        fp: File path to the Parquet data.

    Returns:
        Sorted list of column names.
    """
    freqs = convert_to_freq_dict(pl.scan_parquet(fp))
    return sorted(freqs)


def get_feature_freqs(fp: Path) -> dict[str, int]:
    """Retrieves feature frequencies from a parquet file.

    Args:
        fp: File path to the Parquet data.

    Returns:
        Dictionary mapping each feature to its frequency.
    """
    lazy_df = pl.scan_parquet(fp)
    return convert_to_freq_dict(lazy_df)


def filter_to_codes(
allowed_codes: list[str] | None,
min_code_inclusion_frequency: int,
code_metadata_fp: Path,
):
"""Returns intersection of allowed codes if they are specified, and filters to codes based on inclusion
frequency."""
) -> list[str]:
"""Filters and returns codes based on allowed list and minimum frequency.
Args:
allowed_codes: List of allowed codes, None means all codes are allowed.
min_code_inclusion_frequency: Minimum frequency a code must have to be included.
code_metadata_fp: Path to the metadata file containing code information.
Returns:
Sorted list of the intersection of allowed codes (if they are specified) and filters based on
inclusion frequency.
"""
if allowed_codes is None:
allowed_codes = get_feature_columns(code_metadata_fp)
feature_freqs = get_feature_freqs(code_metadata_fp)
Expand All @@ -129,7 +161,15 @@ def filter_to_codes(
# OmegaConf.register_new_resolver("filter_to_codes", filter_to_codes)


def clear_code_aggregation_suffix(code):
def clear_code_aggregation_suffix(code: str) -> str:
"""Removes aggregation suffixes from code strings.
Args:
code: Code string to be cleared.
Returns:
Code string without aggregation suffixes.
"""
if code.endswith("/code"):
return code[:-5]
elif code.endswith("/value"):
Expand All @@ -140,36 +180,40 @@ def clear_code_aggregation_suffix(code):
return code[:-13]


def filter_parquet(fp, allowed_codes: list[str]):
"""Loads Parquet with Polars and filters to allowed codes.
def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
"""Loads and filters a Parquet file with Polars to include only specified codes and removes rare
codes/values.
Args:
fp: Path to the Meds cohort shard
allowed_codes: List of codes to filter to.
Expect:
>>> from tempfile import NamedTemporaryFile
>>> fp = NamedTemporaryFile()
>>> pl.DataFrame({
... "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
... "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numerical_value": [1, None, 2, 2, None, 5, None, 3]
... }).write_parquet(fp.name)
>>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
shape: (6, 3)
┌──────┬────────────┬─────────────────┐
│ code ┆ timestamp ┆ numerical_value │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞══════╪════════════╪═════════════════╡
│ A ┆ 2021-01-01 ┆ null │
│ A ┆ 2021-01-01 ┆ null │
│ D ┆ null ┆ null │
│ D ┆ null ┆ null │
│ E ┆ 2021-01-03 ┆ null │
│ E ┆ 2021-01-04 ┆ 3 │
└──────┴────────────┴─────────────────┘
>>> fp.close()
fp: Path to the Parquet file of a Meds cohort shard.
allowed_codes: List of codes to filter by.
Returns:
pl.LazyFrame: A filtered LazyFrame containing only the allowed and not rare codes/values.
Examples:
>>> from tempfile import NamedTemporaryFile
>>> fp = NamedTemporaryFile()
>>> pl.DataFrame({
... "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
... "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numerical_value": [1, None, 2, 2, None, 5, None, 3]
... }).write_parquet(fp.name)
>>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
shape: (6, 3)
┌──────┬────────────┬─────────────────┐
│ code ┆ timestamp ┆ numerical_value │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞══════╪════════════╪═════════════════╡
│ A ┆ 2021-01-01 ┆ null │
│ A ┆ 2021-01-01 ┆ null │
│ D ┆ null ┆ null │
│ D ┆ null ┆ null │
│ E ┆ 2021-01-03 ┆ null │
│ E ┆ 2021-01-04 ┆ 3 │
└──────┴────────────┴─────────────────┘
>>> fp.close()
"""
df = pl.scan_parquet(fp)
# Drop values that are rare
Expand Down
98 changes: 90 additions & 8 deletions src/MEDS_tabular_automl/file_name.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,110 @@
"""Help functions for getting file names and paths for MEDS tabular automl tasks."""
"""Helper functions for getting file names and paths for MEDS tabular automl tasks."""
from pathlib import Path

from omegaconf import DictConfig

def list_subdir_files(dir: [Path | str], fmt: str):
return sorted(list(Path(dir).glob(f"**/*.{fmt}")))

def list_subdir_files(root: Path | str, ext: str) -> list[Path]:
"""List files in subdirectories of a directory with a given extension.
def get_task_specific_path(cfg, split, shard_num, window_size, agg):
return Path(cfg.input_dir) / split / f"{shard_num}" / f"{window_size}" / f"{agg}.npz"
Args:
root: Path to the directory.
ext: File extension to filter files.
Returns:
An alphabetically sorted list of Path objects to files matching the extension in any level of
subdirectories of the given directory.
def get_model_files(cfg, split: str, shard_num: int):
Examples:
>>> import tempfile
>>> tmpdir = tempfile.TemporaryDirectory()
>>> root = Path(tmpdir.name)
>>> subdir_1 = root / "subdir_1"
>>> subdir_1.mkdir()
>>> subdir_2 = root / "subdir_2"
>>> subdir_2.mkdir()
>>> subdir_1_A = subdir_1 / "A"
>>> subdir_1_A.mkdir()
>>> (root / "1.csv").touch()
>>> (root / "foo.parquet").touch()
>>> (root / "2.csv").touch()
>>> (root / "subdir_1" / "3.csv").touch()
>>> (root / "subdir_2" / "4.csv").touch()
>>> (root / "subdir_1" / "A" / "5.csv").touch()
>>> (root / "subdir_1" / "A" / "15.csv.gz").touch()
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "csv")] # doctest: +NORMALIZE_WHITESPACE
[PosixPath('1.csv'),
PosixPath('2.csv'),
PosixPath('subdir_1/3.csv'),
PosixPath('subdir_1/A/5.csv'),
PosixPath('subdir_2/4.csv')]
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "parquet")]
[PosixPath('foo.parquet')]
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "csv.gz")]
[PosixPath('subdir_1/A/15.csv.gz')]
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "json")]
[]
>>> list_subdir_files(root / "nonexistent", "csv")
[]
>>> tmpdir.cleanup()
"""

return sorted(list(Path(root).glob(f"**/*.{ext}")))


def get_model_files(cfg: DictConfig, split: str, shard: str) -> list[Path]:
    """Gets the tabularized npz files for a given split and shard.

    TODO: Rename function to get_tabularized_input_files or something.

    Args:
        cfg: `OmegaConf.DictConfig` object with the configuration. It must have the following keys:
            - input_dir: Path to the directory with the tabularized npz files.
            - tabularization: Tabularization configuration, as a nested `DictConfig` object with keys:
                - window_sizes: List of window sizes.
                - aggs: List of aggregation functions.
        split: Split name to reference the files stored on disk.
        shard: The shard within the split to reference the files stored on disk.

    Returns:
        An alphabetically sorted list of Path objects to the tabularized npz files for the given
        split and shard. These files will take the form
        ``{cfg.input_dir}/{split}/{shard}/{window_size}/{agg}.npz``. For static aggregations, the
        window size will be "none" as these features are not time-varying.

    Examples:
        >>> cfg = DictConfig({
        ...     "input_dir": "data",
        ...     "tabularization": {
        ...         "window_sizes": ["1d", "7d"],
        ...         "aggs": ["code/count", "value/sum", "static/present"],
        ...     }
        ... })
        >>> get_model_files(cfg, "train", "0")  # doctest: +NORMALIZE_WHITESPACE
        [PosixPath('data/train/0/1d/code/count.npz'),
         PosixPath('data/train/0/1d/value/sum.npz'),
         PosixPath('data/train/0/7d/code/count.npz'),
         PosixPath('data/train/0/7d/value/sum.npz'),
         PosixPath('data/train/0/none/static/present.npz')]
    """
    aggs = cfg.tabularization.aggs
    shard_dir = Path(cfg.input_dir) / split / shard
    model_files = []
    # Time-varying aggregations exist once per configured window size.
    for window_size in cfg.tabularization.window_sizes:
        for agg in aggs:
            if not agg.startswith("static"):
                model_files.append(shard_dir / window_size / f"{agg}.npz")
    # Static aggregations are not time-varying, so they live under the fixed "none" window.
    for agg in aggs:
        if agg.startswith("static"):
            model_files.append(shard_dir / "none" / f"{agg}.npz")
    return sorted(model_files)
Loading

0 comments on commit 1a3cfe2

Please sign in to comment.