Adding docstrings to src/MEDS_tabular_automl/utils.py #13

Merged
merged 21 commits into from
Jun 13, 2024

Changes from all commits
21 commits
3d6fa43
Added docstring for first function
mmcdermott Jun 12, 2024
ad7ced3
Added doctests for list_subdir_files
mmcdermott Jun 12, 2024
92d2a2c
Removed rarely used and unnecessary function
mmcdermott Jun 12, 2024
fe8971d
Cleaned up and added doctests for get_model_files
mmcdermott Jun 12, 2024
b2883a8
Merge pull request #8 from mmcdermott/file_name_docs
mmcdermott Jun 12, 2024
4f63c37
Merge branch 'main' into improve_docstrings_and_tests
mmcdermott Jun 12, 2024
48e5fb7
docstring for first function of describe_codes
aleksiakolo Jun 12, 2024
f2f7a1f
Updated docstring of second funcion
aleksiakolo Jun 12, 2024
4d44305
Updated docstring of third funcion
aleksiakolo Jun 12, 2024
361a066
Updated docstring of fourth funcion
aleksiakolo Jun 12, 2024
47713cd
Updated docstring of last funcion
aleksiakolo Jun 12, 2024
c3da59c
Merge pull request #10 from mmcdermott/describe_codes
mmcdermott Jun 12, 2024
d1dc2c7
Merge branch 'main' into improve_docstrings_and_tests
mmcdermott Jun 12, 2024
8d3c5f9
Merge branch 'improve_docstrings_and_tests' of github.com:mmcdermott/…
mmcdermott Jun 12, 2024
97b6160
Added docstrings for src/MEDS_tabular_automl/utils.py
aleksiakolo Jun 12, 2024
169efa9
first eight docstrings
aleksiakolo Jun 12, 2024
5854556
docstrings for getting feature columns functions
aleksiakolo Jun 12, 2024
846a604
docstrings up to load_meds_data
aleksiakolo Jun 12, 2024
d7d72b1
finished utils docstrings
aleksiakolo Jun 12, 2024
4afccdf
added DF_T
aleksiakolo Jun 13, 2024
3f4b87c
added previous DF_T top file comments
aleksiakolo Jun 13, 2024
178 changes: 111 additions & 67 deletions src/MEDS_tabular_automl/describe_codes.py
@@ -6,33 +6,40 @@
from MEDS_tabular_automl.utils import DF_T, get_feature_names


def convert_to_df(freq_dict):
def convert_to_df(freq_dict: dict[str, int]) -> pl.DataFrame:
"""Converts a dictionary of code frequencies to a Polars DataFrame.

Args:
freq_dict: A dictionary with code features and their respective frequencies.

Returns:
A DataFrame with two columns, "code" and "count".
"""
return pl.DataFrame([[col, freq] for col, freq in freq_dict.items()], schema=["code", "count"])
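
A minimal usage sketch (hedged: the exact table repr may vary across polars versions):
>>> convert_to_df({"A/code": 3, "B/value": 1})
shape: (2, 2)
┌─────────┬───────┐
│ code    ┆ count │
│ ---     ┆ ---   │
│ str     ┆ i64   │
╞═════════╪═══════╡
│ A/code  ┆ 3     │
│ B/value ┆ 1     │
└─────────┴───────┘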


def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> pl.DataFrame:
"""Generates a DataFrame of code and numeric-value frequencies under the configured aggregations,
computed from the features present in the given data shard.

Args:
cfg: Configuration dictionary specifying how features should be evaluated and aggregated.
shard_df: A DataFrame holding the data for a single shard (e.g., one shard of the 'train' split).

Returns:
A DataFrame with "code" and "count" columns giving the frequency of each code-based and
value-based feature observed in the shard.

Examples:
# >>> import polars as pl
# >>> data = {'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'],
# ... 'timestamp': [None, '2021-01-01', None, None, '2021-01-03', '2021-01-04', None],
# ... 'numerical_value': [1, None, 2, 2, None, None, 3]}
# >>> df = pl.DataFrame(data).lazy()
# >>> aggs = ['value/sum', 'code/count']
# >>> compute_feature_frequencies(aggs, df)
# ['A/code', 'A/value', 'C/code', 'C/value']
"""
static_df = shard_df.filter(
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_null()
@@ -64,48 +71,73 @@ def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> list[str]:
return convert_to_df(combined_freqs)
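
A rough sketch of the calling convention (shard_fp is hypothetical, and which cfg keys are read lives in the elided body above):
# freq_df = compute_feature_frequencies(cfg, pl.scan_parquet(shard_fp))
# freq_df follows the "code"/"count" schema produced by convert_to_df.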


def convert_to_freq_dict(df: pl.LazyFrame) -> dict[str, int]:
"""Converts a DataFrame to a dictionary of frequencies.

This function collects the DataFrame's "code" and "count" columns into a plain dictionary
mapping each code to its frequency.

Args:
df: The LazyFrame to be converted; it must have exactly the columns "code" and "count".

Returns:
A dictionary mapping each code to its frequency count.

Raises:
ValueError: If the DataFrame does not have the expected columns "code" and "count".

Examples:
>>> import polars as pl
>>> df = pl.DataFrame({
... "code": ["A", "B", "C"],
... "count": [1, 2, 3]
... }).lazy()
>>> convert_to_freq_dict(df)
{'A': 1, 'B': 2, 'C': 3}
"""
if df.columns != ["code", "count"]:
raise ValueError(f"DataFrame must have columns 'code' and 'count', but has columns {df.columns}!")
return dict(df.collect().iter_rows())


def get_feature_columns(fp: Path) -> list[str]:
"""Retrieves feature column names from a parquet file.

Args:
fp: File path to the Parquet data.

Returns:
Sorted list of column names.
"""
return sorted(list(convert_to_freq_dict(pl.scan_parquet(fp)).keys()))


def get_feature_freqs(fp: Path) -> dict[str, int]:
"""Retrieves feature frequencies from a parquet file.

Args:
fp: File path to the Parquet data.

Returns:
Dictionary of feature frequencies.
"""
return convert_to_freq_dict(pl.scan_parquet(fp))
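
A small end-to-end sketch covering both readers above, using a throwaway Parquet file (the data values are made up for illustration):
>>> from tempfile import NamedTemporaryFile
>>> f = NamedTemporaryFile(suffix=".parquet")
>>> pl.DataFrame({"code": ["B", "A"], "count": [2, 5]}).write_parquet(f.name)
>>> get_feature_columns(f.name)
['A', 'B']
>>> get_feature_freqs(f.name)
{'B': 2, 'A': 5}
>>> f.close()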


def filter_to_codes(
allowed_codes: list[str] | None,
min_code_inclusion_frequency: int,
code_metadata_fp: Path,
) -> list[str]:
"""Filters and returns codes based on allowed list and minimum frequency.

Args:
allowed_codes: List of allowed codes; if None, all codes are allowed.
min_code_inclusion_frequency: Minimum frequency a code must have to be included.
code_metadata_fp: Path to the metadata file containing code information.

Returns:
A sorted list of codes, restricted to the allowed codes (when specified) and filtered to those
meeting the minimum inclusion frequency.
"""
if allowed_codes is None:
allowed_codes = get_feature_columns(code_metadata_fp)
feature_freqs = get_feature_freqs(code_metadata_fp)
@@ -129,7 +161,15 @@ def filter_to_codes(
# OmegaConf.register_new_resolver("filter_to_codes", filter_to_codes)
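
A hedged usage sketch continuing the toy Parquet file above (the thresholding body is elided in this hunk, so the exact comparison semantics follow that code):
# codes = filter_to_codes(None, min_code_inclusion_frequency=3, code_metadata_fp=f.name)
# with the toy counts {"B": 2, "A": 5}, a threshold of 3 would keep only ["A"]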


def clear_code_aggregation_suffix(code: str) -> str:
"""Removes aggregation suffixes from code strings.

Args:
code: The code string from which to strip the aggregation suffix.

Returns:
The code string with its aggregation suffix removed.
"""
if code.endswith("/code"):
return code[:-5]
elif code.endswith("/value"):
@@ -140,36 +180,40 @@ def filter_parquet(fp, allowed_codes: list[str]):
return code[:-13]
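
A spot-check of the two branches visible above:
>>> clear_code_aggregation_suffix("A/code")
'A'
>>> clear_code_aggregation_suffix("B/value")
'B'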


def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
"""Loads a Parquet file with Polars and filters it to the allowed codes, dropping rare codes/values.

Args:
fp: Path to the Parquet file of a MEDS cohort shard.
allowed_codes: List of codes to filter by.

Returns:
A LazyFrame containing only rows for the allowed codes, with rare codes/values removed.

Examples:
>>> from tempfile import NamedTemporaryFile
>>> fp = NamedTemporaryFile()
>>> pl.DataFrame({
... "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
... "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
... "numerical_value": [1, None, 2, 2, None, 5, None, 3]
... }).write_parquet(fp.name)
>>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
shape: (6, 3)
┌──────┬────────────┬─────────────────┐
│ code ┆ timestamp ┆ numerical_value │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞══════╪════════════╪═════════════════╡
│ A ┆ 2021-01-01 ┆ null │
│ A ┆ 2021-01-01 ┆ null │
│ D ┆ null ┆ null │
│ D ┆ null ┆ null │
│ E ┆ 2021-01-03 ┆ null │
│ E ┆ 2021-01-04 ┆ 3 │
└──────┴────────────┴─────────────────┘
>>> fp.close()
"""
df = pl.scan_parquet(fp)
# Drop values that are rare
98 changes: 90 additions & 8 deletions src/MEDS_tabular_automl/file_name.py
@@ -1,28 +1,110 @@
"""Help functions for getting file names and paths for MEDS tabular automl tasks."""
"""Helper functions for getting file names and paths for MEDS tabular automl tasks."""
from pathlib import Path

from omegaconf import DictConfig


def list_subdir_files(root: Path | str, ext: str) -> list[Path]:
"""List files in subdirectories of a directory with a given extension.

Args:
root: Path to the directory.
ext: File extension to filter files.

Returns:
An alphabetically sorted list of Path objects to files matching the extension in any level of
subdirectories of the given directory.

Examples:
>>> import tempfile
>>> tmpdir = tempfile.TemporaryDirectory()
>>> root = Path(tmpdir.name)
>>> subdir_1 = root / "subdir_1"
>>> subdir_1.mkdir()
>>> subdir_2 = root / "subdir_2"
>>> subdir_2.mkdir()
>>> subdir_1_A = subdir_1 / "A"
>>> subdir_1_A.mkdir()
>>> (root / "1.csv").touch()
>>> (root / "foo.parquet").touch()
>>> (root / "2.csv").touch()
>>> (root / "subdir_1" / "3.csv").touch()
>>> (root / "subdir_2" / "4.csv").touch()
>>> (root / "subdir_1" / "A" / "5.csv").touch()
>>> (root / "subdir_1" / "A" / "15.csv.gz").touch()
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "csv")] # doctest: +NORMALIZE_WHITESPACE
[PosixPath('1.csv'),
PosixPath('2.csv'),
PosixPath('subdir_1/3.csv'),
PosixPath('subdir_1/A/5.csv'),
PosixPath('subdir_2/4.csv')]
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "parquet")]
[PosixPath('foo.parquet')]
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "csv.gz")]
[PosixPath('subdir_1/A/15.csv.gz')]
>>> [fp.relative_to(root) for fp in list_subdir_files(root, "json")]
[]
>>> list_subdir_files(root / "nonexistent", "csv")
[]
>>> tmpdir.cleanup()
"""

return sorted(list(Path(root).glob(f"**/*.{ext}")))


def get_model_files(cfg: DictConfig, split: str, shard: str) -> list[Path]:
"""Get the tabularized npz files for a given split and shard number.

TODO: Rename function to get_tabularized_input_files or something.

Args:
cfg: `OmegaConf.DictConfig` object with the configuration. It must have the following keys:
- input_dir: Path to the directory with the tabularized npz files.
- tabularization: Tabularization configuration, as a nested `DictConfig` object with keys:
- window_sizes: List of window sizes.
- aggs: List of aggregation functions.
split: Split name to reference the files stored on disk.
shard: The shard within the split to reference the files stored on disk.

Returns:
An alphabetically sorted list of Path objects to the tabularized npz files for the given split and
shard. These files will take the form ``{cfg.input_dir}/{split}/{shard}/{window_size}/{agg}.npz``. For
static aggregations, the window size will be "none" as these features are not time-varying.

Examples:
>>> cfg = DictConfig({
... "input_dir": "data",
... "tabularization": {
... "window_sizes": ["1d", "7d"],
... "aggs": ["code/count", "value/sum", "static/present"],
... }
... })
>>> get_model_files(cfg, "train", "0") # doctest: +NORMALIZE_WHITESPACE
[PosixPath('data/train/0/1d/code/count.npz'),
PosixPath('data/train/0/1d/value/sum.npz'),
PosixPath('data/train/0/7d/code/count.npz'),
PosixPath('data/train/0/7d/value/sum.npz'),
PosixPath('data/train/0/none/static/present.npz')]
>>> get_model_files(cfg, "test/IID", "3/0") # doctest: +NORMALIZE_WHITESPACE
[PosixPath('data/test/IID/3/0/1d/code/count.npz'),
PosixPath('data/test/IID/3/0/1d/value/sum.npz'),
PosixPath('data/test/IID/3/0/7d/code/count.npz'),
PosixPath('data/test/IID/3/0/7d/value/sum.npz'),
PosixPath('data/test/IID/3/0/none/static/present.npz')]
"""
window_sizes = cfg.tabularization.window_sizes
aggs = cfg.tabularization.aggs
shard_dir = Path(cfg.input_dir) / split / shard
model_files = []
for window_size in window_sizes:
for agg in aggs:
if agg.startswith("static"):
continue
else:
model_files.append(shard_dir / window_size / f"{agg}.npz")
for agg in aggs:
if agg.startswith("static"):
window_size = "none"
model_files.append(shard_dir / window_size / f"{agg}.npz")
return sorted(model_files)