
Commit

Enhance data aggregation framework with dynamic window and aggregation handling

- Introduce VALID_AGGREGATIONS to define permissible aggregations.
- Implement time_aggd_col_alias_fntr to generate dynamic column aliases based on window size and aggregation (see the sketch below).
- Extend get_agg_pl_expr to build expressions dynamically from the aggregation type and window size, handling both cumulative ("full") and windowed aggregations.
- Enhance _generate_summary to apply a specified aggregation over a given window size.
- Update generate_summary to handle multiple dataframes, aggregate data using the specified window sizes and aggregations, and ensure all specified feature columns appear in the final output, adding missing ones with default values.
- Add extensive doctests for the summarization functions, demonstrating usage with both code and value data types.
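As a quick illustration of the column-naming convention these helpers produce, here is a minimal standalone sketch (mine, not part of the commit); alias mirrors the behavior of time_aggd_col_alias_fntr defined below:

def alias(window_size: str, agg: str):
    # Column aliases follow "<window_size>/<original column>/<agg name>".
    def f(column: str) -> str:
        return "/".join([window_size] + column.split("/") + [agg])
    return f

assert alias("1d", "sum")("value/A") == "1d/value/A/sum"
assert alias("full", "count")("code/B") == "full/code/B/count"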
Oufattole committed May 27, 2024
1 parent cd067f8 commit c8ca3bb
Showing 5 changed files with 299 additions and 23 deletions.
10 changes: 3 additions & 7 deletions configs/tabularize.yaml
@@ -8,18 +8,14 @@ window_sizes: ???
codes: null
aggs:
- "code/count"
- "code/time_since_last"
- "code/time_since_first"
- "code/present"
- "value/count"
- "value/present"
- "value/sum"
- "value/sum_sqd"
- "value/min"
- "value/time_since_min"
- "value/max"
- "value/time_since_max"
- "value/last"
- "value/slope"
- "value/intercept"
- "value/first"
dynamic_threshold: 0.01
numerical_value_threshold: 0.1
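The aggregations configured here are meant to line up with the VALID_AGGREGATIONS list introduced below in generate_summarized_reps.py (the copied diff view does not distinguish deleted entries from retained ones). A minimal guard for catching unsupported entries at startup, offered as a sketch rather than part of the commit:

from MEDS_tabular_automl.generate_summarized_reps import VALID_AGGREGATIONS

def check_aggs(aggs: list[str]) -> None:
    # Fail fast on any configured aggregation the framework does not support.
    unsupported = [a for a in aggs if a not in VALID_AGGREGATIONS]
    if unsupported:
        raise ValueError(f"Unsupported aggregations: {unsupported}")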

46 changes: 44 additions & 2 deletions scripts/summarize_over_windows.py
@@ -1,10 +1,15 @@
"""WIP."""


from pathlib import Path

import hydra
import polars as pl
from loguru import logger
from omegaconf import DictConfig

from MEDS_tabular_automl.utils import setup_environment
from MEDS_tabular_automl.generate_summarized_reps import generate_summary
from MEDS_tabular_automl.utils import setup_environment, write_df


@hydra.main(version_base=None, config_path="../configs", config_name="tabularize")
@@ -50,4 +55,41 @@ def summarize_ts_data_over_windows(
.. _link: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.groupby_rolling.html # noqa: E501
"""
setup_environment(cfg)
flat_dir, _, feature_columns = setup_environment(cfg)

    # Locate the time-series shards written by the tabularization step
ts_dir = Path(cfg.tabularized_data_dir) / "ts"
ts_fps = list(ts_dir.glob("*/*.parquet"))
splits = {fp.parent.stem for fp in ts_fps}

split_to_pair_fps = {}
for split in splits:
        # Categorize files by identifier (base name without '_code' or '_value') using a dict comprehension
categorized_files = {
file.stem.rsplit("_", 1)[0]: {"code": None, "value": None}
for file in ts_fps
if file.parent.stem == split
}
for file in ts_fps:
if file.parent.stem == split:
identifier = file.stem.rsplit("_", 1)[0]
suffix = file.stem.split("_")[-1] # 'code' or 'value'
categorized_files[identifier][suffix] = file

# Process categorized files into pairs ensuring code is first and value is second
code_value_pairs = [
(info["code"], info["value"])
for info in categorized_files.values()
if info["code"] is not None and info["value"] is not None
]

split_to_pair_fps[split] = code_value_pairs

    # Summarize each code/value shard pair and write the result per split
for split, pairs in split_to_pair_fps.items():
logger.info(f"Processing {split}:")
for code_file, value_file in pairs:
logger.info(f" - Code file: {code_file}, Value file: {value_file}")
            summary_df = generate_summary(
                feature_columns,
                [pl.scan_parquet(code_file), pl.scan_parquet(value_file)],
                cfg.window_sizes,
                cfg.aggs,
            )
shard_number = code_file.stem.rsplit("_", 1)[0]
write_df(summary_df, flat_dir / split / f"{shard_number}.parquet")
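For concreteness, a standalone sketch of the shard-pairing convention the loop above relies on (the file names here are hypothetical):

from pathlib import Path

# Each split directory holds "<shard>_code.parquet" and "<shard>_value.parquet";
# only complete code/value pairs are summarized.
ts_fps = [
    Path("ts/train/0_code.parquet"),
    Path("ts/train/0_value.parquet"),
    Path("ts/tuning/0_code.parquet"),  # no matching value file
]

pairs: dict[tuple[str, str], dict[str, Path]] = {}
for fp in ts_fps:
    shard, suffix = fp.stem.rsplit("_", 1)  # e.g. ("0", "code")
    pairs.setdefault((fp.parent.stem, shard), {})[suffix] = fp

# Only ("train", "0") has both entries; the tuning shard is dropped.
complete = {k: v for k, v in pairs.items() if {"code", "value"} <= v.keys()}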
4 changes: 2 additions & 2 deletions src/MEDS_tabular_automl/generate_static_features.py
@@ -13,7 +13,7 @@
from MEDS_tabular_automl.utils import DF_T, add_missing_cols, parse_flat_feature_column


def _summarize_static_measurements(
def summarize_static_measurements(
feature_columns: list[str],
df: DF_T,
) -> pl.LazyFrame:
@@ -98,7 +98,7 @@ def get_flat_static_rep(
suitable for further analysis or machine learning tasks.
"""
static_features = [c for c in feature_columns if c.startswith("STATIC_")]
static_measurements = _summarize_static_measurements(static_features, df=shard_df)
static_measurements = summarize_static_measurements(static_features, df=shard_df)
# fill up missing feature columns with nulls
normalized_measurements = add_missing_cols(
static_measurements,
250 changes: 250 additions & 0 deletions src/MEDS_tabular_automl/generate_summarized_reps.py
@@ -0,0 +1,250 @@
from collections.abc import Callable

import polars as pl
import polars.selectors as cs

from MEDS_tabular_automl.utils import DF_T

VALID_AGGREGATIONS = [
"code/count",
"value/count",
"value/has_values_count",
"value/sum",
"value/sum_sqd",
"value/min",
"value/max",
"value/first",
]


def time_aggd_col_alias_fntr(window_size: str, agg: str) -> Callable[[str], str]:
assert agg is not None, "agg must be provided"

def f(c: str) -> str:
return "/".join([window_size] + c.split("/") + [agg])

return f


def get_agg_pl_expr(window_size: str, agg: str):
code_cols = cs.starts_with("code/")
value_cols = cs.starts_with("value/")
if window_size == "full":
match agg:
case "code/count":
return code_cols.cumsum().map_alias(time_aggd_col_alias_fntr(window_size, "count"))
case "value/count":
return (
value_cols.is_not_null()
.cumsum()
.map_alias(time_aggd_col_alias_fntr(window_size, "count"))
)
case "value/has_values_count":
return (
(value_cols.is_not_null() & value_cols.is_not_nan())
.cumsum()
.map_alias(time_aggd_col_alias_fntr(window_size, "has_values_count"))
)
case "value/sum":
return value_cols.cumsum().map_alias(time_aggd_col_alias_fntr(window_size, "sum"))
case "value/sum_sqd":
return (value_cols**2).cumsum().map_alias(time_aggd_col_alias_fntr(window_size, "sum_sqd"))
case "value/min":
value_cols.cummin().map_alias(time_aggd_col_alias_fntr(window_size, "min"))
case "value/max":
value_cols.cummax().map_alias(time_aggd_col_alias_fntr(window_size, "max"))
case _:
raise ValueError(f"Invalid aggregation `{agg}` for window_size `{window_size}`")
else:
match agg:
case "code/count":
return code_cols.sum().map_alias(time_aggd_col_alias_fntr(window_size, "count"))
case "value/count":
return (
value_cols.is_not_null().sum().map_alias(time_aggd_col_alias_fntr(window_size, "count"))
)
case "value/has_values_count":
return (
(value_cols.is_not_null() & value_cols.is_not_nan())
.sum()
.map_alias(time_aggd_col_alias_fntr(window_size, "has_values_count"))
)
case "value/sum":
return value_cols.sum().map_alias(time_aggd_col_alias_fntr(window_size, "sum"))
case "value/sum_sqd":
return (value_cols**2).sum().map_alias(time_aggd_col_alias_fntr(window_size, "sum_sqd"))
case "value/min":
value_cols.min().map_alias(time_aggd_col_alias_fntr(window_size, "min"))
case "value/max":
value_cols.max().map_alias(time_aggd_col_alias_fntr(window_size, "max"))
case _:
raise ValueError(f"Invalid aggregation `{agg}` for window_size `{window_size}`")


def _generate_summary(df: DF_T, window_size: str, agg: str) -> pl.LazyFrame:
"""Generate a summary of the data frame for a given window size and aggregation.
Args:
- df (DF_T): The data frame to summarize.
- window_size (str): The window size to use for the summary.
- agg (str): The aggregation to apply to the data frame.
Returns:
- pl.LazyFrame: The summarized data frame.
Expect:
>>> from datetime import date
>>> code_df = pl.DataFrame({"patient_id": [1, 1, 1, 2],
... "code/A": [1, 1, 0, 0],
... "code/B": [0, 0, 1, 1],
... "timestamp": [date(2021, 1, 1), date(2021, 1, 2),date(2020, 1, 3), date(2021, 1, 4)],
... }).lazy()
>>> _generate_summary(code_df.lazy(), "full", "code/count"
... ).collect().sort(["patient_id", "timestamp"])
shape: (4, 4)
┌────────────┬────────────┬───────────────────┬───────────────────┐
│ patient_id ┆ timestamp ┆ full/code/A/count ┆ full/code/B/count │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ date ┆ i64 ┆ i64 │
╞════════════╪════════════╪═══════════════════╪═══════════════════╡
│ 1 ┆ 2020-01-03 ┆ 2 ┆ 1 │
│ 1 ┆ 2021-01-01 ┆ 1 ┆ 0 │
│ 1 ┆ 2021-01-02 ┆ 2 ┆ 0 │
│ 2 ┆ 2021-01-04 ┆ 0 ┆ 1 │
└────────────┴────────────┴───────────────────┴───────────────────┘
>>> value_df = pl.DataFrame({"patient_id": [1, 1, 1, 2],
... "timestamp": [date(2021, 1, 1), date(2021, 1, 2),
... date(2020, 1, 3), date(2021, 1, 4)],
... "value/A": [1, 2, 3, None],
... "value/B": [None, None, None, 4.0],})
>>> _generate_summary(value_df.lazy(), "full", "value/sum").collect().sort(
... ["patient_id", "timestamp"])
shape: (4, 4)
┌────────────┬────────────┬──────────────────┬──────────────────┐
│ patient_id ┆ timestamp ┆ full/value/A/sum ┆ full/value/B/sum │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ date ┆ i64 ┆ f64 │
╞════════════╪════════════╪══════════════════╪══════════════════╡
│ 1 ┆ 2020-01-03 ┆ 6 ┆ null │
│ 1 ┆ 2021-01-01 ┆ 1 ┆ null │
│ 1 ┆ 2021-01-02 ┆ 3 ┆ null │
│ 2 ┆ 2021-01-04 ┆ null ┆ 4.0 │
└────────────┴────────────┴──────────────────┴──────────────────┘
>>> _generate_summary(value_df.lazy(), "1d", "value/count").collect().sort(
... ["patient_id", "timestamp"])
shape: (4, 4)
┌────────────┬────────────┬──────────────────┬──────────────────┐
│ patient_id ┆ timestamp ┆ 1d/value/A/count ┆ 1d/value/B/count │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ date ┆ u32 ┆ u32 │
╞════════════╪════════════╪══════════════════╪══════════════════╡
│ 1 ┆ 2020-01-03 ┆ 1 ┆ 0 │
│ 1 ┆ 2021-01-01 ┆ 1 ┆ 0 │
│ 1 ┆ 2021-01-02 ┆ 1 ┆ 0 │
│ 2 ┆ 2021-01-04 ┆ 0 ┆ 1 │
└────────────┴────────────┴──────────────────┴──────────────────┘
"""
assert agg in VALID_AGGREGATIONS, f"Invalid aggregation: {agg}"
assert agg.split("/")[0] in [
c.split("/")[0] for c in df.columns
], f"df is invalid, no column with prefix: `{agg.split('/')[0]}`"

if window_size == "full":
out_df = df.groupby("patient_id").agg(
"timestamp",
get_agg_pl_expr(window_size, agg),
)
out_df = out_df.explode(*[c for c in out_df.columns if c != "patient_id"])
else:
out_df = (
df.sort(["patient_id", "timestamp"])
.groupby_rolling(
index_column="timestamp",
by="patient_id",
period=window_size,
)
.agg(
get_agg_pl_expr(window_size, agg),
)
)

return out_df


def generate_summary(
feature_columns: list[str], dfs: list[pl.LazyFrame], window_sizes: list[str], aggregations: list[str]
) -> pl.LazyFrame:
"""Generate a summary of the data frame for given window sizes and aggregations.
This function processes a dataframe to apply specified aggregations over defined window sizes.
It then joins the resulting frames on 'patient_id' and 'timestamp', and ensures all specified
feature columns exist in the final output, adding missing ones with default values.
Args:
feature_columns (list[str]): List of all feature columns that must exist in the final output.
df (list[pl.LazyFrame]): The input dataframes to process, expected to be length 2 list with code_df
(pivoted shard with binary presence of codes) and value_df (pivoted shard with numerical values
for each code).
window_sizes (list[str]): List of window sizes to apply for summarization.
aggregations (list[str]): List of aggregations to perform within each window size.
Returns:
pl.LazyFrame: A LazyFrame containing the summarized data with all required features present.
Expect:
>>> from datetime import date
>>> value_df = pl.DataFrame({"patient_id": [1, 1, 1, 2],
... "timestamp": [date(2021, 1, 1), date(2021, 1, 2),date(2020, 1, 3), date(2021, 1, 4)],
... "value/A": [1, 2, 3, None],
... "value/B": [None, None, None, 4.0],})
>>> code_df = pl.DataFrame({"patient_id": [1, 1, 1, 2],
... "code/A": [1, 1, 0, 0],
... "code/B": [0, 0, 1, 1],
... "timestamp": [date(2021, 1, 1), date(2021, 1, 2),date(2020, 1, 3), date(2021, 1, 5)],
... }).lazy()
>>> feature_columns = ["code/A", "code/B", "value/A", "value/B"]
>>> aggregations = ["code/count", "value/sum"]
>>> window_sizes = ["full", "1d"]
>>> out_df = generate_summary(feature_columns, [value_df.lazy(), code_df.lazy()],
... window_sizes, aggregations).collect().sort(["patient_id", "timestamp"])
>>> print(out_df.shape)
(5, 10)
>>> for c in out_df.columns: print(c)
patient_id
timestamp
1d/code/A/count
1d/code/B/count
1d/value/A/sum
1d/value/B/sum
full/code/A/count
full/code/B/count
full/value/A/sum
full/value/B/sum
"""
final_columns = []
out_dfs = []
# Generate summaries for each window size and aggregation
for window_size in window_sizes:
for agg in aggregations:
code_type, agg_name = agg.split("/")
final_columns.extend(
[f"{window_size}/{c}/{agg_name}" for c in feature_columns if c.startswith(code_type)]
)
for df in dfs:
if agg.split("/")[0] in [c.split("/")[0] for c in df.columns]:
out_df = _generate_summary(df, window_size, agg)
out_dfs.append(out_df)

final_columns = sorted(final_columns)
# Combine all dataframes using successive joins
result_df = out_dfs[0]
for df in out_dfs[1:]:
result_df = result_df.join(df, on=["patient_id", "timestamp"], how="outer", coalesce=True)

# Add in missing feature columns with default values
existing_columns = result_df.columns
for column in final_columns:
if column not in existing_columns:
result_df = result_df.with_columns(pl.lit(None).alias(column))
result_df = result_df.select(pl.col(*["patient_id", "timestamp"], *final_columns))
return result_df
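The doctests above cover code/count, value/count, and value/sum, but not the value/min or value/max branches. A minimal sketch of the expected cumulative-minimum behavior, with expected values computed by hand rather than taken from the diff:

from datetime import date

import polars as pl

from MEDS_tabular_automl.generate_summarized_reps import _generate_summary

value_df = pl.DataFrame({
    "patient_id": [1, 1, 2],
    "timestamp": [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 1)],
    "value/A": [3.0, 1.0, 5.0],
}).lazy()

# Cumulative per-patient minimum; expected "full/value/A/min" values are
# [3.0, 1.0] for patient 1 and [5.0] for patient 2.
out = _generate_summary(value_df, "full", "value/min").collect()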
12 changes: 0 additions & 12 deletions src/MEDS_tabular_automl/generate_ts_features.py
@@ -2,18 +2,6 @@

from MEDS_tabular_automl.utils import DF_T

VALID_AGGREGATIONS = [
"sum",
"sum_sqd",
"min",
"max",
"value",
"first",
"present",
"count",
"has_values_count",
]


def summarize_dynamic_measurements(
ts_columns: list[str],
