Skip to content

Commit

Permalink
Merge pull request #9 from mmcdermott/improve_docstrings_and_tests
Browse files Browse the repository at this point in the history
Improve docstrings and tests
  • Loading branch information
mmcdermott authored Jun 13, 2024
2 parents 9bd3081 + a3a928b commit 4eda4dd
Show file tree
Hide file tree
Showing 13 changed files with 508 additions and 590 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:

- name: Install packages
run: |
pip install -e .[tests]
pip install .[tests]
#----------------------------------------------
# run test suite
Expand Down
182 changes: 104 additions & 78 deletions src/MEDS_tabular_automl/describe_codes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pathlib import Path

import polars as pl
from omegaconf import DictConfig

from MEDS_tabular_automl.utils import DF_T, get_feature_names

Expand All @@ -14,32 +13,92 @@ def convert_to_df(freq_dict: dict[str, int]) -> pl.DataFrame:
Returns:
A DataFrame with two columns, "code" and "count".
TODOs:
- Eliminate this function and just use a DataFrame throughout. See #14
- Use categorical types for `code` instead of strings.
Examples:
>>> convert_to_df({"A": 1, "B": 2, "C": 3})
shape: (3, 2)
┌──────┬───────┐
│ code ┆ count │
│ --- ┆ --- │
│ str ┆ i64 │
╞══════╪═══════╡
│ A ┆ 1 │
│ B ┆ 2 │
│ C ┆ 3 │
└──────┴───────┘
"""
return pl.DataFrame([[col, freq] for col, freq in freq_dict.items()], schema=["code", "count"])


def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> pl.DataFrame:
def convert_to_freq_dict(df: pl.LazyFrame) -> dict[str, int]:
    """Converts a two-column ("code", "count") LazyFrame into a flat ``{code: count}`` dictionary.

    Args:
        df: The LazyFrame to be converted. Its columns must be exactly "code" and "count", in that order.

    Returns:
        A dictionary mapping each value of the "code" column to the value of the "count" column in the
        same row. If a code appears in several rows, the last row wins.

    Raises:
        ValueError: If the DataFrame does not have the expected columns "code" and "count".

    TODOs:
        - Eliminate this function and just use a DataFrame throughout. See #14

    Example:
        >>> import polars as pl
        >>> data = pl.DataFrame({"code": [1, 2, 3, 4, 5], "count": [10, 20, 30, 40, 50]}).lazy()
        >>> convert_to_freq_dict(data)
        {1: 10, 2: 20, 3: 30, 4: 40, 5: 50}
        >>> convert_to_freq_dict(pl.DataFrame({"code": ["A", "B", "C"], "value": [1, 2, 3]}).lazy())
        Traceback (most recent call last):
            ...
        ValueError: DataFrame must have columns 'code' and 'count', but has columns ['code', 'value']!
    """
    # Read the column list once; it is needed for both the check and the error message.
    columns = df.columns
    if columns != ["code", "count"]:
        raise ValueError(f"DataFrame must have columns 'code' and 'count', but has columns {columns}!")
    # Each collected row is a (code, count) tuple, which dict() consumes directly.
    return dict(df.collect().iter_rows())


def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame:
"""Generates a DataFrame containing the frequencies of codes and numerical values under different
aggregations by computing frequency counts for certain attributes and organizing the results into specific
categories based on the dataset's features.
Args:
cfg: Configuration dictionary specifying how features should be evaluated and aggregated.
shard_df: A DataFrame containing the data to be analyzed and split (e.g., 'train', 'test').
Returns:
A DataFrame with two columns, "code" and "count", giving the frequency of each feature name
found in the shard.
Examples:
# >>> import polars as pl
# >>> data = {'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'],
# ... 'timestamp': [None, '2021-01-01', None, None, '2021-01-03', '2021-01-04', None],
# ... 'numerical_value': [1, None, 2, 2, None, None, 3]}
# >>> df = pl.DataFrame(data).lazy()
# >>> aggs = ['value/sum', 'code/count']
# >>> compute_feature_frequencies(aggs, df)
# ['A/code', 'A/value', 'C/code', 'C/value']
>>> from datetime import datetime
>>> data = pl.DataFrame({
... 'patient_id': [1, 1, 2, 2, 3, 3, 3],
... 'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'],
... 'timestamp': [
... None,
... datetime(2021, 1, 1),
... None,
... None,
... datetime(2021, 1, 3),
... datetime(2021, 1, 4),
... None
... ],
... 'numerical_value': [1, None, 2, 2, None, None, 3]
... }).lazy()
>>> assert (
... convert_to_freq_dict(compute_feature_frequencies(data).lazy()) == {
... 'B/static/present': 2, 'C/static/present': 1, 'A/static/present': 1, 'B/static/first': 2,
... 'C/static/first': 1, 'A/static/first': 1, 'A/code': 1, 'C/code': 2
... }
... )
"""
static_df = shard_df.filter(
pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_null()
Expand Down Expand Up @@ -71,33 +130,6 @@ def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> pl.DataFrame
return convert_to_df(combined_freqs)


def convert_to_freq_dict(df: pl.LazyFrame) -> dict[str, int]:
    """Converts a two-column ("code", "count") LazyFrame into a flat ``{code: count}`` dictionary.

    Args:
        df: The LazyFrame to be converted. Its columns must be exactly "code" and "count", in that order.

    Returns:
        A dictionary mapping each value of the "code" column to the value of the "count" column in the
        same row. If a code appears in several rows, the last row wins.

    Raises:
        ValueError: If the DataFrame does not have the expected columns "code" and "count".

    Example:
        >>> import polars as pl
        >>> df = pl.DataFrame({
        ...     "code": [1, 2, 3, 4, 5],
        ...     "count": [10, 20, 30, 40, 50]
        ... }).lazy()
        >>> convert_to_freq_dict(df)
        {1: 10, 2: 20, 3: 30, 4: 40, 5: 50}
    """
    # Read the column list once; it is needed for both the check and the error message.
    columns = df.columns
    if columns != ["code", "count"]:
        raise ValueError(f"DataFrame must have columns 'code' and 'count', but has columns {columns}!")
    # Each collected row is a (code, count) tuple, which dict() consumes directly.
    return dict(df.collect().iter_rows())


def get_feature_columns(fp: Path) -> list[str]:
"""Retrieves feature column names from a parquet file.
Expand All @@ -106,8 +138,15 @@ def get_feature_columns(fp: Path) -> list[str]:
Returns:
Sorted list of column names.
Examples:
>>> from tempfile import NamedTemporaryFile
>>> with NamedTemporaryFile() as f:
... pl.DataFrame({"code": ["E", "D", "A"], "count": [1, 3, 2]}).write_parquet(f.name)
... get_feature_columns(f.name)
['A', 'D', 'E']
"""
return sorted(list(convert_to_freq_dict(pl.scan_parquet(fp)).keys()))
return sorted(list(get_feature_freqs(fp).keys()))


def get_feature_freqs(fp: Path) -> dict[str, int]:
Expand All @@ -118,47 +157,15 @@ def get_feature_freqs(fp: Path) -> dict[str, int]:
Returns:
Dictionary of feature frequencies.
"""
return convert_to_freq_dict(pl.scan_parquet(fp))

def filter_to_codes(
allowed_codes: list[str] | None,
min_code_inclusion_frequency: int,
code_metadata_fp: Path,
) -> list[str]:
"""Filters and returns codes based on allowed list and minimum frequency.
Args:
allowed_codes: List of allowed codes, None means all codes are allowed.
min_code_inclusion_frequency: Minimum frequency a code must have to be included.
code_metadata_fp: Path to the metadata file containing code information.
Returns:
Sorted list of the intersection of allowed codes (if they are specified) and filters based on
inclusion frequency.
Examples:
>>> from tempfile import NamedTemporaryFile
>>> with NamedTemporaryFile() as f:
... pl.DataFrame({"code": ["E", "D", "A"], "count": [1, 3, 2]}).write_parquet(f.name)
... get_feature_freqs(f.name)
{'E': 1, 'D': 3, 'A': 2}
"""
if allowed_codes is None:
allowed_codes = get_feature_columns(code_metadata_fp)
feature_freqs = get_feature_freqs(code_metadata_fp)
allowed_codes_set = set(allowed_codes)

filtered_codes = [
code
for code, freq in feature_freqs.items()
if freq >= min_code_inclusion_frequency and code in allowed_codes_set
]
return sorted(filtered_codes)

# code_freqs = {
# code: freq
# for code, freq in feature_freqs.items()
# if (freq >= min_code_inclusion_frequency and code in set(allowed_codes))
# }
# return sorted([code for code, freq in code_freqs.items() if freq >= min_code_inclusion_frequency])


# OmegaConf.register_new_resolver("filter_to_codes", filter_to_codes)
return convert_to_freq_dict(pl.scan_parquet(fp))


def clear_code_aggregation_suffix(code: str) -> str:
Expand All @@ -169,6 +176,23 @@ def clear_code_aggregation_suffix(code: str) -> str:
Returns:
Code string without aggregation suffixes.
Raises:
ValueError: If the code does not have a recognized aggregation suffix.
Examples:
>>> clear_code_aggregation_suffix("A/code")
'A'
>>> clear_code_aggregation_suffix("A/value")
'A'
>>> clear_code_aggregation_suffix("A/static/present")
'A'
>>> clear_code_aggregation_suffix("A/static/first")
'A'
>>> clear_code_aggregation_suffix("A")
Traceback (most recent call last):
...
ValueError: Code A does not have a recognized aggregation suffix!
"""
if code.endswith("/code"):
return code[:-5]
Expand All @@ -178,6 +202,8 @@ def clear_code_aggregation_suffix(code: str) -> str:
return code[:-15]
elif code.endswith("/static/first"):
return code[:-13]
else:
raise ValueError(f"Code {code} does not have a recognized aggregation suffix!")


def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame:
Expand Down
75 changes: 43 additions & 32 deletions src/MEDS_tabular_automl/generate_static_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
efficient data manipulation.
Functions:
- _summarize_static_measurements: Summarizes static measurements from a given DataFrame.
- convert_to_matrix: Converts a Polars DataFrame to a sparse matrix.
- get_sparse_static_rep: Merges static and time-series dataframes into a sparse representation.
- summarize_static_measurements: Summarizes static measurements from a given DataFrame.
- get_flat_static_rep: Produces a tabular representation of static data features.
"""

Expand All @@ -14,7 +16,6 @@
from scipy.sparse import coo_array, csr_array

from MEDS_tabular_automl.utils import (
DF_T,
STATIC_CODE_AGGREGATION,
STATIC_VALUE_AGGREGATION,
get_events_df,
Expand All @@ -24,8 +25,17 @@
)


def convert_to_matrix(df, num_events, num_features):
"""Converts a Polars DataFrame to a sparse matrix."""
def convert_to_matrix(df: pl.DataFrame, num_events: int, num_features: int) -> csr_array:
"""Converts a Polars DataFrame to a sparse matrix.
Args:
df: The DataFrame to convert.
num_events: Number of events to set matrix dimension.
num_features: Number of features to set matrix dimension.
Returns:
A sparse matrix representation of the DataFrame.
"""
dense_matrix = df.drop("patient_id").collect().to_numpy()
data_list = []
rows = []
Expand All @@ -41,18 +51,19 @@ def convert_to_matrix(df, num_events, num_features):
return matrix


def get_sparse_static_rep(static_features, static_df, meds_df, feature_columns) -> coo_array:
"""Merges static and time-series dataframes.
This function merges the static and time-series dataframes based on the patient_id column.
def get_sparse_static_rep(
static_features: list[str], static_df: pl.DataFrame, meds_df: pl.DataFrame, feature_columns: list[str]
) -> coo_array:
"""Merges static and time-series dataframes into a sparse representation based on the patient_id column.
Args:
- feature_columns (List[str]): A list of feature columns to include in the merged dataframe.
- static_df (pd.DataFrame): A dataframe containing static features.
- ts_df (pd.DataFrame): A dataframe containing time-series features.
static_features: A list of static feature names.
static_df: A DataFrame containing static features.
meds_df: A DataFrame containing time-series features.
feature_columns: A list of feature columns to include in the merged DataFrame.
Returns:
- pd.DataFrame: A merged dataframe containing static and time-series features.
A sparse array representation of the merged static and time-series features.
"""
# Make static data sparse and merge it with the time-series data
logger.info("Make static data sparse and merge it with the time-series data")
Expand Down Expand Up @@ -85,22 +96,21 @@ def get_sparse_static_rep(static_features, static_df, meds_df, feature_columns)
def summarize_static_measurements(
agg: str,
feature_columns: list[str],
df: DF_T,
df: pl.LazyFrame,
) -> pl.LazyFrame:
"""Aggregates static measurements for feature columns that are marked as 'present' or 'first'.
Parameters:
- feature_columns (list[str]): List of feature column identifiers that are specifically marked
for static analysis.
- df (DF_T): Data frame from which features will be extracted and summarized.
Returns:
- pl.LazyFrame: A LazyFrame containing the summarized data pivoted by 'patient_id'
for each static feature.
This function first filters for features that need to be recorded as the first occurrence
or simply as present, then performs a pivot to reshape the data for each patient, providing
a tabular format where each row represents a patient and each column represents a static feature.
Args:
agg: The type of aggregation ('present' or 'first').
feature_columns: A list of feature column identifiers marked for static analysis.
df: The DataFrame from which features will be extracted and summarized.
Returns:
A LazyFrame containing summarized data pivoted by 'patient_id' for each static feature.
"""
if agg == STATIC_VALUE_AGGREGATION:
static_features = get_feature_names(agg=agg, feature_columns=feature_columns)
Expand Down Expand Up @@ -157,20 +167,21 @@ def summarize_static_measurements(
def get_flat_static_rep(
agg: str,
feature_columns: list[str],
shard_df: DF_T,
shard_df: pl.LazyFrame,
) -> coo_array:
"""Produces a raw representation for static data from a specified shard DataFrame.
"""Produces a sparse representation for static data from a specified shard DataFrame.
Parameters:
- feature_columns (list[str]): List of feature columns to include in the static representation.
- shard_df (DF_T): The shard DataFrame containing patient data.
This function selects the appropriate static features, summarizes them using
summarize_static_measurements, and then normalizes the resulting data to ensure
it is suitable for further analysis or machine learning tasks.
Returns:
- pl.LazyFrame: A LazyFrame that includes all static features for the data provided.
Args:
agg: The aggregation method for static data.
feature_columns: A list of feature columns to include.
shard_df: The shard DataFrame containing the patient data.
This function selects the appropriate static features, summarizes them using
summarize_static_measurements, and then normalizes the resulting data to ensure it is
suitable for further analysis or machine learning tasks.
Returns:
A sparse array representing the static features for the provided shard of data.
"""
static_features = get_feature_names(agg=agg, feature_columns=feature_columns)
static_measurements = summarize_static_measurements(agg, static_features, df=shard_df)
Expand Down
Loading

0 comments on commit 4eda4dd

Please sign in to comment.