passing unit tests for sparse aggregations (only code/count and value/sum implemented at the moment)
Oufattole committed May 29, 2024
1 parent 6f3b1ec commit 6753609
Showing 3 changed files with 25 additions and 19 deletions.
scripts/summarize_over_windows.py (1 change: 1 addition & 0 deletions)
@@ -71,6 +71,7 @@ def summarize_ts_data_over_windows(
         cfg.window_sizes,
         cfg.aggs,
     )
+    assert summary_df.shape[1] > 2, "No data found in the summarized dataframe"
 
     logger.info("Writing pivot file")
     write_df(summary_df, pivot_fp, do_overwrite=cfg.do_overwrite)
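The new assert is a cheap sanity check: the summarized dataframe's first two columns are presumably patient_id and timestamp, so shape[1] > 2 only holds when at least one aggregation column was actually produced. A minimal sketch of the same guard in isolation (the column names below are illustrative assumptions, not taken from the repo):

```python
import pandas as pd

# Hypothetical summarized output: two key columns plus one aggregation column.
summary_df = pd.DataFrame(
    {
        "patient_id": [1, 1],
        "timestamp": pd.to_datetime(["2020-01-01", "2020-01-02"]),
        "1d/A/code/count": [1, 2],  # assumed "{window}/{feature}/{agg}" naming
    }
)

# Same guard as the added line: fail fast if no aggregation columns exist.
assert summary_df.shape[1] > 2, "No data found in the summarized dataframe"
```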
src/MEDS_tabular_automl/generate_summarized_reps.py (9 changes: 7 additions & 2 deletions)
@@ -6,6 +6,8 @@
 import polars as pl
 from scipy.sparse import coo_matrix
 
+from MEDS_tabular_automl.generate_ts_features import get_ts_columns
+
 CODE_AGGREGATIONS = [
     "code/count",
 ]
@@ -240,7 +242,7 @@ def generate_summary(
     >>> wide_df['timestamp'] = pd.to_datetime(wide_df['timestamp'])
     >>> for col in ["A/code", "B/code", "A/value", "B/value"]:
     ...     wide_df[col] = pd.arrays.SparseArray(wide_df[col])
-    >>> feature_columns = ["A/code", "B/code", "A/value", "B/value"]
+    >>> feature_columns = ["A/code/count", "B/code/count", "A/value/sum", "B/value/sum"]
     >>> aggregations = ["code/count", "value/sum"]
     >>> window_sizes = ["full", "1d"]
     >>> generate_summary(feature_columns, wide_df, window_sizes, aggregations)[
@@ -264,14 +266,17 @@ def generate_summary(
     0  NaN NaN  0
     """
     df = df.sort_values(["patient_id", "timestamp"])
+    assert len(feature_columns), "feature_columns must be a non-empty list"
+    ts_columns = get_ts_columns(feature_columns)
+    code_value_ts_columns = [f"{c}/code" for c in ts_columns] + [f"{c}/value" for c in ts_columns]
     final_columns = []
     out_dfs = []
     # Generate summaries for each window size and aggregation
     for window_size in window_sizes:
         for agg in aggregations:
             code_type, agg_name = agg.split("/")
             final_columns.extend(
-                [f"{window_size}/{c}/{agg_name}" for c in feature_columns if c.endswith(code_type)]
+                [f"{window_size}/{c}/{agg_name}" for c in code_value_ts_columns if c.endswith(code_type)]
             )
             # only iterate through code_types that exist in the dataframe columns
             if any([c.endswith(code_type) for c in df.columns]):
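The new import suggests get_ts_columns maps aggregation-qualified feature names (for example "A/code/count") back to their base time-series names, which are then re-expanded into "/code" and "/value" variants. A plausible sketch under that assumption (the real implementation lives in MEDS_tabular_automl.generate_ts_features and may differ):

```python
def get_ts_columns(feature_columns: list[str]) -> list[str]:
    """Hypothetical sketch: strip the trailing "/code/count" or "/value/sum"
    aggregation suffix from each feature column and deduplicate base names."""
    bases: list[str] = []
    for col in feature_columns:
        base = col.rsplit("/", 2)[0]  # "A/code/count" -> "A" (assumed naming)
        if base not in bases:
            bases.append(base)
    return bases


# With the doctest's feature_columns this yields ["A", "B"], so
# code_value_ts_columns becomes ["A/code", "B/code", "A/value", "B/value"],
# matching the old hard-coded list while accepting the new suffixed names.
print(get_ts_columns(["A/code/count", "B/code/count", "A/value/sum", "B/value/sum"]))
```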
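The loop in the hunk above then derives the final wide-column names by splitting each aggregation string on "/" into a code type and an aggregation name. A worked illustration of that naming scheme, using the same values as the doctest (this is a standalone rerun of the diff's logic, not repository code):

```python
window_sizes = ["full", "1d"]
aggregations = ["code/count", "value/sum"]
code_value_ts_columns = ["A/code", "B/code", "A/value", "B/value"]

final_columns = []
for window_size in window_sizes:
    for agg in aggregations:
        code_type, agg_name = agg.split("/")  # "code/count" -> ("code", "count")
        final_columns.extend(
            f"{window_size}/{c}/{agg_name}"
            for c in code_value_ts_columns
            if c.endswith(code_type)
        )

# Produces names such as "full/A/code/count" and "1d/B/value/sum".
print(final_columns)
```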
tests/test_tabularize.py (34 changes: 17 additions & 17 deletions)
@@ -12,6 +12,7 @@
 from loguru import logger
 
 from scripts.identify_columns import store_columns
+from scripts.summarize_over_windows import summarize_ts_data_over_windows
 from scripts.tabularize_static import tabularize_static_data
 from scripts.tabularize_ts import tabularize_ts_data
 
@@ -126,9 +127,10 @@ def test_tabularize():
         "tabularized_data_dir": str(tabularized_data_dir.resolve()),
         "min_code_inclusion_frequency": 1,
         "window_sizes": ["30d", "365d", "full"],
+        "aggs": ["code/count", "value/sum"],
         "codes": None,
         "n_patients_per_sub_shard": 2,
-        "do_overwrite": False,
+        "do_overwrite": True,
         "do_update": True,
         "seed": 1,
         "hydra.verbose": True,
@@ -156,19 +158,17 @@ def test_tabularize():
     ]
     assert set(actual_files) == set(expected_files)
 
-    # summarize_ts_data_over_windows(cfg)
-    # # confirm summary files exist:
-    # actual_files = [
-    #     (f.parent.stem, f.stem) for f in list(tabularized_data_dir.glob("summary/*/*.parquet"))
-    # ]
-    # expected_files = [
-    #     ("train", "1"),
-    #     ("train", "0"),
-    #     ("held_out", "0"),
-    #     ("tuning", "0"),
-    # ]
-    # assert set(actual_files) == set(expected_files)
-    # for f in list(tabularized_data_dir.glob("summary/*/*.parquet")):
-    #     df = pl.read_parquet(f)
-    #     assert df.shape[0] > 0
-    #     assert df.columns == ["hi"]
+    summarize_ts_data_over_windows(cfg)
+    # confirm summary files exist:
+    actual_files = [(f.parent.stem, f.stem) for f in list(tabularized_data_dir.glob("ts/*/*.parquet"))]
+    expected_files = [
+        ("train", "1"),
+        ("train", "0"),
+        ("held_out", "0"),
+        ("tuning", "0"),
+    ]
+    assert set(actual_files) == set(expected_files)
+    for f in list(tabularized_data_dir.glob("summary/*/*.parquet")):
+        df = pl.read_parquet(f)
+        assert df.shape[0] > 0
+        assert df.columns == ["hi"]
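The re-enabled block exercises summarize_ts_data_over_windows directly and asserts one parquet file per (split, shard) pair. A compact sketch of the same existence check, with the ts/<split>/<shard>.parquet layout inferred from the test's glob pattern rather than stated anywhere in the repo:

```python
from pathlib import Path


def found_shards(tabularized_data_dir: Path) -> set[tuple[str, str]]:
    # Mirror the test's glob: collect (split, shard) pairs from ts/<split>/<shard>.parquet.
    return {(f.parent.stem, f.stem) for f in tabularized_data_dir.glob("ts/*/*.parquet")}


expected = {("train", "0"), ("train", "1"), ("held_out", "0"), ("tuning", "0")}
# assert found_shards(Path(cfg.tabularized_data_dir)) == expected  # cfg as built in the test
```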
