Merge branch 'improve_docstrings_and_tests' of github.com:mmcdermott/MEDS_Tabular_AutoML into improve_docstrings_and_tests
mmcdermott committed Jun 12, 2024
2 parents d1dc2c7 + c3da59c commit 8d3c5f9
Showing 1 changed file with 111 additions and 67 deletions.
178 changes: 111 additions & 67 deletions src/MEDS_tabular_automl/describe_codes.py
@@ -6,33 +6,40 @@
 from MEDS_tabular_automl.utils import DF_T, get_feature_names


-def convert_to_df(freq_dict):
+def convert_to_df(freq_dict: dict[str, int]) -> pl.DataFrame:
+    """Converts a dictionary of code frequencies to a Polars DataFrame.
+
+    Args:
+        freq_dict: A dictionary with code features and their respective frequencies.
+
+    Returns:
+        A DataFrame with two columns, "code" and "count".
+    """
     return pl.DataFrame([[col, freq] for col, freq in freq_dict.items()], schema=["code", "count"])
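For orientation, here is a minimal usage sketch of the newly typed convert_to_df; the module path follows this repository's src/ layout, and the frequency values are invented for illustration:

    import polars as pl

    from MEDS_tabular_automl.describe_codes import convert_to_df

    freq_dict = {"A/code": 3, "A/value": 2}  # hypothetical feature frequencies
    df = convert_to_df(freq_dict)
    print(df)  # shape: (2, 2), with columns "code" and "count"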


-def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> list[str]:
-    """Generates a list of feature column names from the data within each shard based on specified
-    configurations.
+def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> pl.DataFrame:
+    """Generates a DataFrame of the frequencies of codes and numerical values under the configured
+    aggregations, computed from the features present in the shard.
-    Parameters:
-    - cfg (DictConfig): Configuration dictionary specifying how features should be evaluated and aggregated.
-    - split_to_shard_df (dict): A dictionary of DataFrames, divided by data split (e.g., 'train', 'test').
+    Args:
+        cfg: Configuration dictionary specifying how features should be evaluated and aggregated.
+        shard_df: A DataFrame containing the data for one shard of a split (e.g., 'train', 'test').
     Returns:
-    - tuple[list[str], dict]: A tuple containing a list of feature columns and a dictionary of code properties
-      identified during the evaluation.
+        A DataFrame with "code" and "count" columns, giving the frequency of each feature
+        identified during the evaluation.
     This function evaluates the properties of codes within training data and applies configured
     aggregations to generate a comprehensive list of feature columns for modeling purposes.
     Examples:
         # >>> import polars as pl
         # >>> data = {'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'],
         # ...         'timestamp': [None, '2021-01-01', None, None, '2021-01-03', '2021-01-04', None],
         # ...         'numerical_value': [1, None, 2, 2, None, None, 3]}
         # >>> df = pl.DataFrame(data).lazy()
         # >>> aggs = ['value/sum', 'code/count']
         # >>> compute_feature_frequencies(aggs, df)
         # ['A/code', 'A/value', 'C/code', 'C/value']
     """
     static_df = shard_df.filter(
         pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_null()
@@ -64,48 +64,73 @@ def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> list[str]:
     return convert_to_df(combined_freqs)


-def convert_to_freq_dict(df: pl.LazyFrame) -> dict:
+def convert_to_freq_dict(df: pl.LazyFrame) -> dict[str, int]:
     """Converts a DataFrame to a dictionary of frequencies.
-    This function converts a DataFrame to a dictionary of frequencies, where the keys are the
-    column names and the values are dictionaries of code frequencies.
+
+    This function converts a DataFrame with "code" and "count" columns to a dictionary that maps
+    each code to its frequency count.
     Args:
-    - df (pl.DataFrame): The DataFrame to be converted.
+        df: The DataFrame to be converted.
     Returns:
-    - dict: A dictionary of frequencies, where the keys are the column names and the values are
-      dictionaries of code frequencies.
+        A dictionary mapping each code to its frequency count.
+
+    Raises:
+        ValueError: If the DataFrame does not have the expected columns "code" and "count".
     Example:
-        # >>> import polars as pl
-        # >>> df = pl.DataFrame({
-        # ...     "code": [1, 2, 3, 4, 5],
-        # ...     "value": [10, 20, 30, 40, 50]
-        # ... })
-        # >>> convert_to_freq_dict(df)
-        # {'code': {1: 1, 2: 1, 3: 1, 4: 1, 5: 1}, 'value': {10: 1, 20: 1, 30: 1, 40: 1, 50: 1}}
+        # >>> import polars as pl
+        # >>> df = pl.LazyFrame({
+        # ...     "code": ["A", "B", "C"],
+        # ...     "count": [1, 2, 3]
+        # ... })
+        # >>> convert_to_freq_dict(df)
+        # {'A': 1, 'B': 2, 'C': 3}
     """
     if not df.columns == ["code", "count"]:
         raise ValueError(f"DataFrame must have columns 'code' and 'count', but has columns {df.columns}!")
     return dict(df.collect().iter_rows())
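Because convert_to_df and convert_to_freq_dict are inverses over the "code"/"count" schema, a round trip is an easy sanity check; a minimal sketch with invented frequencies:

    from MEDS_tabular_automl.describe_codes import convert_to_df, convert_to_freq_dict

    freqs = {"A/code": 3, "A/value": 2}  # hypothetical feature frequencies
    assert convert_to_freq_dict(convert_to_df(freqs).lazy()) == freqs
    # Anything without exactly the columns ["code", "count"] raises ValueError.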


-def get_feature_columns(fp):
+def get_feature_columns(fp: Path) -> list[str]:
+    """Retrieves feature column names from a parquet file.
+
+    Args:
+        fp: File path to the Parquet data.
+
+    Returns:
+        Sorted list of feature column names.
+    """
     return sorted(list(convert_to_freq_dict(pl.scan_parquet(fp)).keys()))


-def get_feature_freqs(fp):
+def get_feature_freqs(fp: Path) -> dict[str, int]:
+    """Retrieves feature frequencies from a parquet file.
+
+    Args:
+        fp: File path to the Parquet data.
+
+    Returns:
+        Dictionary of feature frequencies.
+    """
     return convert_to_freq_dict(pl.scan_parquet(fp))


 def filter_to_codes(
     allowed_codes: list[str] | None,
     min_code_inclusion_frequency: int,
     code_metadata_fp: Path,
-):
-    """Returns intersection of allowed codes if they are specified, and filters to codes based on inclusion
-    frequency."""
+) -> list[str]:
+    """Filters and returns codes based on an allowed list and a minimum frequency.
+
+    Args:
+        allowed_codes: List of allowed codes; None means all codes are allowed.
+        min_code_inclusion_frequency: Minimum frequency a code must have to be included.
+        code_metadata_fp: Path to the metadata file containing code information.
+
+    Returns:
+        A sorted list of the codes that are allowed (when a list is specified) and that meet the
+        minimum inclusion frequency.
+    """
     if allowed_codes is None:
         allowed_codes = get_feature_columns(code_metadata_fp)
     feature_freqs = get_feature_freqs(code_metadata_fp)
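The two parquet helpers above compose with convert_to_df in the obvious way; a minimal sketch using a temporary directory (the metadata file name and frequencies are hypothetical):

    from pathlib import Path
    from tempfile import TemporaryDirectory

    from MEDS_tabular_automl.describe_codes import (
        convert_to_df, get_feature_columns, get_feature_freqs,
    )

    with TemporaryDirectory() as tmp:
        fp = Path(tmp) / "code_metadata.parquet"  # hypothetical file name
        convert_to_df({"B/value": 2, "A/code": 3}).write_parquet(fp)
        print(get_feature_columns(fp))  # ['A/code', 'B/value'] -- sorted code names
        print(get_feature_freqs(fp))    # {'B/value': 2, 'A/code': 3}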
@@ -129,7 +161,15 @@ def filter_to_codes(
 # OmegaConf.register_new_resolver("filter_to_codes", filter_to_codes)


-def clear_code_aggregation_suffix(code):
+def clear_code_aggregation_suffix(code: str) -> str:
+    """Removes aggregation suffixes from code strings.
+
+    Args:
+        code: Code string to be cleared.
+
+    Returns:
+        Code string without aggregation suffixes.
+    """
     if code.endswith("/code"):
         return code[:-5]
     elif code.endswith("/value"):
@@ -140,36 +180,40 @@ def clear_code_aggregation_suffix(code):
         return code[:-13]


-def filter_parquet(fp, allowed_codes: list[str]):
-    """Loads Parquet with Polars and filters to allowed codes.
+def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
+    """Loads a Parquet file with Polars and filters it to the specified codes, dropping rare
+    codes/values.
     Args:
-        fp: Path to the Meds cohort shard
-        allowed_codes: List of codes to filter to.
-    Expect:
+        fp: Path to the Parquet file of a MEDS cohort shard.
+        allowed_codes: List of codes to filter by.
+
+    Returns:
+        A LazyFrame containing only the allowed, non-rare codes/values.
+
+    Examples:
         >>> from tempfile import NamedTemporaryFile
         >>> fp = NamedTemporaryFile()
         >>> pl.DataFrame({
         ...     "code": ["A", "A", "A", "A", "D", "D", "E", "E"],
         ...     "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"],
         ...     "numerical_value": [1, None, 2, 2, None, 5, None, 3]
         ... }).write_parquet(fp.name)
         >>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect()
         shape: (6, 3)
         ┌──────┬────────────┬─────────────────┐
         │ code ┆ timestamp  ┆ numerical_value │
         │ ---  ┆ ---        ┆ ---             │
         │ str  ┆ str        ┆ i64             │
         ╞══════╪════════════╪═════════════════╡
         │ A    ┆ 2021-01-01 ┆ null            │
         │ A    ┆ 2021-01-01 ┆ null            │
         │ D    ┆ null       ┆ null            │
         │ D    ┆ null       ┆ null            │
         │ E    ┆ 2021-01-03 ┆ null            │
         │ E    ┆ 2021-01-04 ┆ 3               │
         └──────┴────────────┴─────────────────┘
         >>> fp.close()
     """
     df = pl.scan_parquet(fp)
     # Drop values that are rare
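Since the example in filter_parquet's docstring is a live doctest (unlike the commented-out examples earlier in the file), it can be exercised directly; a minimal sketch:

    import doctest

    import MEDS_tabular_automl.describe_codes as describe_codes

    # Runs every enabled doctest in the module; failed == 0 if the docstrings hold.
    print(doctest.testmod(describe_codes))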
