diff --git a/scripts/identify_columns.py b/scripts/identify_columns.py index d74f811..4334df8 100644 --- a/scripts/identify_columns.py +++ b/scripts/identify_columns.py @@ -98,7 +98,7 @@ def store_columns( .. _link: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.groupby_rolling.html # noqa: E501 """ # create output dir - flat_dir = Path(cfg.tabularized_data_dir) / "flat_reps" + flat_dir = Path(cfg.tabularized_data_dir) flat_dir.mkdir(exist_ok=True, parents=True) # load MEDS data diff --git a/scripts/tabularize_ts.py b/scripts/tabularize_ts.py index 2d9ac95..33e9dec 100644 --- a/scripts/tabularize_ts.py +++ b/scripts/tabularize_ts.py @@ -1,25 +1,21 @@ -"""WIP.""" import hydra from omegaconf import DictConfig +from tqdm import tqdm -from MEDS_tabular_automl.utils import setup_environment +from MEDS_tabular_automl.generate_ts_features import get_flat_ts_rep +from MEDS_tabular_automl.utils import setup_environment, write_df @hydra.main(version_base=None, config_path="../configs", config_name="tabularize") def tabularize_ts_data( cfg: DictConfig, ): - """Writes a flat (historically summarized) representation of the dataset to disk. + """Processes a medical dataset to generate and store flat representations of time-series data. - This file caches a set of files useful for building flat representations of the dataset to disk, - suitable for, e.g., sklearn style modeling for downstream tasks. It will produce a few sets of files: - - * A new directory ``self.config.save_dir / "flat_reps"`` which contains the following: - * A subdirectory ``raw`` which contains: (1) a json file with the configuration arguments and (2) a - set of parquet files containing flat (e.g., wide) representations of summarized events per subject, - broken out by split and subject chunk. - * A set of subdirectories ``past/*`` which contains summarized views over the past ``*`` time period - per subject per event, for all time periods in ``window_sizes``, if any. + This function handles MEDS format data and pivots tables to create two types of data files + with patient_id and timestamp indexes: + code data: containing a column for every code, with 1 and 0 values indicating presence + value data: containing a column for every code, holding the numerical value observed. Args: cfg: @@ -35,8 +31,8 @@ def tabularize_ts_data( specified in this argument. These are strings specifying time deltas, using this syntax: `link`_. Each window size will be summarized to a separate directory, and will share the same subject file split as is used in the raw representation files. - codes: A list of codes to include in the flat representation. If `None`, all codes will be included - in the flat representation. + codes: A list of codes to include in the flat representation. If `None`, all codes will be + included in the flat representation. aggs: A list of aggregations to apply to the raw representation. Must have length greater than 0. n_patients_per_sub_shard: The number of subjects that should be included in each output file. Lowering this number increases the number of files written, making the process of creating and @@ -45,7 +41,29 @@ def tabularize_ts_data( directory. do_update: bool = True seed: The seed to use for random number generation. - - ..
_link: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.groupby_rolling.html # noqa: E501 """ - setup_environment(cfg) + flat_dir, split_to_df, feature_columns = setup_environment(cfg) + # Produce ts representation + ts_subdir = flat_dir / "ts" + + for sp, subjects_dfs in tqdm(list(split_to_df.items()), desc="Flattening Splits"): + sp_dir = ts_subdir / sp + + for i, shard_df in enumerate(tqdm(subjects_dfs, desc="Subject chunks", leave=False)): + code_fp = sp_dir / f"{i}_code.parquet" + value_fp = sp_dir / f"{i}_value.parquet" + if code_fp.exists() or value_fp.exists(): + if cfg.do_update: + continue + elif not cfg.do_overwrite: + raise FileExistsError( + f"do_overwrite is {cfg.do_overwrite} and {code_fp.exists()}" + f" or {value_fp.exists()} exists!" + ) + + code_df, value_df = get_flat_ts_rep( + feature_columns=feature_columns, + shard_df=shard_df, + ) + write_df(code_df, code_fp, do_overwrite=cfg.do_overwrite) + write_df(value_df, value_fp, do_overwrite=cfg.do_overwrite) diff --git a/src/MEDS_tabular_automl/generate_ts_features.py b/src/MEDS_tabular_automl/generate_ts_features.py index 9d5956f..12ac571 100644 --- a/src/MEDS_tabular_automl/generate_ts_features.py +++ b/src/MEDS_tabular_automl/generate_ts_features.py @@ -1,226 +1,137 @@ -"""WIP. - -This file will be used to generate time series features from the raw data. -""" -from collections.abc import Callable -from pathlib import Path - -import numpy as np import polars as pl -import polars.selectors as cs - -from MEDS_tabular_automl.utils import DF_T, add_missing_cols, parse_flat_feature_column - - -def _summarize_dynamic_measurements( - self, - feature_columns: list[str], - include_only_subjects: set[int] | None = None, -) -> pl.LazyFrame: - if include_only_subjects is None: - df = self.dynamic_measurements_df - else: - df = self.dynamic_measurements_df.join( - self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select("event_id"), - on="event_id", - how="inner", - ) - - valid_measures = {} - for feat_col in feature_columns: - temp, meas, feat = parse_flat_feature_column(feat_col) - - if temp != "dynamic": - continue - if meas not in valid_measures: - valid_measures[meas] = set() - valid_measures[meas].add(feat) +from MEDS_tabular_automl.utils import DF_T - out_dfs = {} - for m, allowed_vocab in valid_measures.items(): - cfg = self.measurement_configs[m] - - total_observations = int( - np.ceil( - cfg.observation_rate_per_case - * cfg.observation_rate_over_cases - * sum(self.n_events_per_subject.values()) - ) - ) - - count_type = self.get_smallest_valid_uint_type(total_observations) - - if cfg.modality == "univariate_regression" and cfg.vocabulary is None: - prefix = f"dynamic/{m}/{m}" - - key_col = pl.col(m) - val_col = pl.col(m).drop_nans().cast(pl.Float32) - - out_dfs[m] = ( - df.lazy() - .select("measurement_id", "event_id", m) - .filter(pl.col(m).is_not_null()) - .groupby("event_id") - .agg( - pl.col(m).is_not_null().sum().cast(count_type).alias(f"{prefix}/count"), - ( - (pl.col(m).is_not_nan() & pl.col(m).is_not_null()) - .sum() - .cast(count_type) - .alias(f"{prefix}/has_values_count") - ), - val_col.sum().alias(f"{prefix}/sum"), - (val_col**2).sum().alias(f"{prefix}/sum_sqd"), - val_col.min().alias(f"{prefix}/min"), - val_col.max().alias(f"{prefix}/max"), - ) - ) - continue - elif cfg.modality == "multivariate_regression": - column_cols = [m, m] - values_cols = [m, cfg.values_column] - key_prefix = f"{m}_{m}_" - val_prefix = f"{cfg.values_column}_{m}_" - - 
key_col = cs.starts_with(key_prefix) - val_col = cs.starts_with(val_prefix).drop_nans().cast(pl.Float32) - - aggs = [ - key_col.is_not_null() - .sum() - .cast(count_type) - .map_alias(lambda c: f"dynamic/{m}/{c.replace(key_prefix, '')}/count"), - ( - (cs.starts_with(val_prefix).is_not_null() & cs.starts_with(val_prefix).is_not_nan()) - .sum() - .map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/has_values_count") - ), - val_col.sum().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/sum"), - (val_col**2).sum().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/sum_sqd"), - val_col.min().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/min"), - val_col.max().map_alias(lambda c: f"dynamic/{m}/{c.replace(val_prefix, '')}/max"), - ] - else: - column_cols = [m] - values_cols = [m] - aggs = [ - pl.all().is_not_null().sum().cast(count_type).map_alias(lambda c: f"dynamic/{m}/{c}/count") - ] - - ID_cols = ["measurement_id", "event_id"] - out_dfs[m] = ( - df.select(*ID_cols, *set(column_cols + values_cols)) - .filter(pl.col(m).is_in(allowed_vocab)) - .pivot( - index=ID_cols, - columns=column_cols, - values=values_cols, - aggregate_function=None, - ) - .lazy() - .drop("measurement_id") - .groupby("event_id") - .agg(*aggs) - ) +VALID_AGGREGATIONS = [ + "sum", + "sum_sqd", + "min", + "max", + "value", + "first", + "present", + "count", + "has_values_count", +] - return pl.concat(list(out_dfs.values()), how="align") - -def _summarize_over_window(df: DF_T, window_size: str) -> pl.LazyFrame: - """Apply aggregations to the raw representation over a window size.""" - if isinstance(df, Path): - df = pl.scan_parquet(df) - - def time_aggd_col_alias_fntr(new_agg: str | None = None) -> Callable[[str], str]: - if new_agg is None: - - def f(c: str) -> str: - return "/".join([window_size] + c.split("/")[1:]) - - else: - - def f(c: str) -> str: - return "/".join([window_size] + c.split("/")[1:-1] + [new_agg]) - - return f - - # Columns to convert to counts: - present_indicator_cols = cs.ends_with("/present") - - # Columns to convert to value aggregations: - value_cols = cs.ends_with("/value") - - # Columns to aggregate via other operations - cnt_cols = (cs.ends_with("/count") | cs.ends_with("/has_values_count")).fill_null(0) - - cols_to_sum = cs.ends_with("/sum") | cs.ends_with("/sum_sqd") - cols_to_min = cs.ends_with("/min") - cols_to_max = cs.ends_with("/max") - - if window_size == "FULL": - df = df.groupby("subject_id").agg( - "timestamp", - # present to counts - present_indicator_cols.cumsum().map_alias(time_aggd_col_alias_fntr("count")), - # values to stats - value_cols.is_not_null().cumsum().map_alias(time_aggd_col_alias_fntr("count")), - ( - (value_cols.is_not_null() & value_cols.is_not_nan()) - .cumsum() - .map_alias(time_aggd_col_alias_fntr("has_values_count")) - ), - value_cols.cumsum().map_alias(time_aggd_col_alias_fntr("sum")), - (value_cols**2).cumsum().map_alias(time_aggd_col_alias_fntr("sum_sqd")), - value_cols.cummin().map_alias(time_aggd_col_alias_fntr("min")), - value_cols.cummax().map_alias(time_aggd_col_alias_fntr("max")), - # Raw aggregations - cnt_cols.cumsum().map_alias(time_aggd_col_alias_fntr()), - cols_to_sum.cumsum().map_alias(time_aggd_col_alias_fntr()), - cols_to_min.cummin().map_alias(time_aggd_col_alias_fntr()), - cols_to_max.cummax().map_alias(time_aggd_col_alias_fntr()), - ) - df = df.explode(*[c for c in df.columns if c != "subject_id"]) - else: - df = df.groupby_rolling( - index_column="timestamp", - by="subject_id", - 
period=window_size, - ).agg( - # present to counts - present_indicator_cols.sum().map_alias(time_aggd_col_alias_fntr("count")), - # values to stats - value_cols.is_not_null().sum().map_alias(time_aggd_col_alias_fntr("count")), - ( - (value_cols.is_not_null() & value_cols.is_not_nan()) - .sum() - .map_alias(time_aggd_col_alias_fntr("has_values_count")) - ), - value_cols.sum().map_alias(time_aggd_col_alias_fntr("sum")), - (value_cols**2).sum().map_alias(time_aggd_col_alias_fntr("sum_sqd")), - value_cols.min().map_alias(time_aggd_col_alias_fntr("min")), - value_cols.max().map_alias(time_aggd_col_alias_fntr("max")), - # Raw aggregations - cnt_cols.sum().map_alias(time_aggd_col_alias_fntr()), - cols_to_sum.sum().map_alias(time_aggd_col_alias_fntr()), - cols_to_min.min().map_alias(time_aggd_col_alias_fntr()), - cols_to_max.max().map_alias(time_aggd_col_alias_fntr()), +def summarize_dynamic_measurements( + ts_columns: list[str], + df: DF_T, +) -> pl.LazyFrame: + """Summarize dynamic measurements for feature columns that are marked as 'dynamic'. + + Args: + - ts_columns (list[str]): List of feature column identifiers that are specifically marked for dynamic + analysis. + - shard_df (DF_T): Data frame from which features will be extracted and summarized. + + Returns: + - pl.LazyFrame: A summarized data frame containing the dynamic features. + + Example: + >>> data = {'patient_id': [1, 1, 1, 2], + ... 'code': ['A', 'A', 'B', 'B'], + ... 'timestamp': ['2021-01-01', '2021-01-01', '2020-01-01', '2021-01-04'], + ... 'numerical_value': [1, 2, 2, 2]} + >>> df = pl.DataFrame(data).lazy() + >>> ts_columns = ['A', 'B'] + >>> code_df, value_df = summarize_dynamic_measurements(ts_columns, df) + >>> code_df.collect() + shape: (4, 4) + ┌────────────┬────────┬────────┬────────────┐ + │ patient_id ┆ code/A ┆ code/B ┆ timestamp │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ u8 ┆ u8 ┆ str │ + ╞════════════╪════════╪════════╪════════════╡ + │ 1 ┆ 1 ┆ 0 ┆ 2021-01-01 │ + │ 1 ┆ 1 ┆ 0 ┆ 2021-01-01 │ + │ 1 ┆ 0 ┆ 1 ┆ 2020-01-01 │ + │ 2 ┆ 0 ┆ 1 ┆ 2021-01-04 │ + └────────────┴────────┴────────┴────────────┘ + >>> value_df.collect() + shape: (3, 4) + ┌────────────┬────────────┬─────────┬─────────┐ + │ patient_id ┆ timestamp ┆ value/A ┆ value/B │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ f64 ┆ f64 │ + ╞════════════╪════════════╪═════════╪═════════╡ + │ 1 ┆ 2021-01-01 ┆ 1.5 ┆ null │ + │ 1 ┆ 2020-01-01 ┆ null ┆ 2.0 │ + │ 2 ┆ 2021-01-04 ┆ null ┆ 2.0 │ + └────────────┴────────────┴─────────┴─────────┘ + """ + + value_df = ( + df.select("patient_id", "timestamp", "code", "numerical_value") + .collect() + .pivot( + index=["patient_id", "timestamp"], + columns=["code"], + values=["numerical_value"], + aggregate_function="mean", # TODO round up counts so they are binary + separator="/", ) - - return add_missing_cols(df, set_count_0_to_null=True) + .lazy() + ) + value_df = value_df.rename(lambda c: f"value/{c}" if c not in ["patient_id", "timestamp"] else c) + code_df = df.drop("numerical_value").collect().to_dummies(columns=["code"], separator="/").lazy() + return code_df, value_df def get_flat_ts_rep( feature_columns: list[str], - **kwargs, + shard_df: DF_T, ) -> pl.LazyFrame: - """Produce raw representation for dynamic data.""" - - return add_missing_cols( - _summarize_dynamic_measurements(feature_columns, **kwargs) - .sort(by=["subject_id", "timestamp"]) - .collect() - .lazy(), - [c for c in feature_columns if c.startswith("dynamic")], - ) - # The above .collect().lazy() shouldn't be necessary but it appears to be for some reason... 
+ """Produce a flat time series representation from a given data frame, focusing on non-static feature + columns. + + This function filters the given data frame for non-static features based on the 'feature_columns' + provided and generates a flat time series representation using these dynamic features. The resulting + data frame includes both codes and values transformed and aggregated appropriately. + + Args: + feature_columns (list[str]): A list of column identifiers that determine which features are considered + for dynamic analysis. + shard_df (DF_T): The data frame containing time-stamped data from which features will be extracted + and summarized. + + Returns: + pl.LazyFrame: A LazyFrame consisting of the processed time series data, combining both code and value + representations. + + Example: + >>> feature_columns = ['A', 'B', 'C', "static/A"] + >>> data = {'patient_id': [1, 1, 1, 2, 2, 2], + ... 'code': ['A', 'A', 'B', 'B', 'C', 'C'], + ... 'timestamp': ['2021-01-01', '2021-01-01', '2020-01-01', '2021-01-04', None, None], + ... 'numerical_value': [1, 2, 2, 2, 3, 4]} + >>> df = pl.DataFrame(data).lazy() + >>> code_df, value_df = get_flat_ts_rep(feature_columns, df) + >>> code_df.collect() + shape: (4, 4) + ┌────────────┬────────┬────────┬────────────┐ + │ patient_id ┆ code/A ┆ code/B ┆ timestamp │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ u8 ┆ u8 ┆ str │ + ╞════════════╪════════╪════════╪════════════╡ + │ 1 ┆ 1 ┆ 0 ┆ 2021-01-01 │ + │ 1 ┆ 1 ┆ 0 ┆ 2021-01-01 │ + │ 1 ┆ 0 ┆ 1 ┆ 2020-01-01 │ + │ 2 ┆ 0 ┆ 1 ┆ 2021-01-04 │ + └────────────┴────────┴────────┴────────────┘ + >>> value_df.collect() + shape: (3, 4) + ┌────────────┬────────────┬─────────┬─────────┐ + │ patient_id ┆ timestamp ┆ value/A ┆ value/B │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ f64 ┆ f64 │ + ╞════════════╪════════════╪═════════╪═════════╡ + │ 1 ┆ 2021-01-01 ┆ 1.5 ┆ null │ + │ 1 ┆ 2020-01-01 ┆ null ┆ 2.0 │ + │ 2 ┆ 2021-01-04 ┆ null ┆ 2.0 │ + └────────────┴────────────┴─────────┴─────────┘ + """ + ts_columns = [c for c in feature_columns if not c.startswith("static")] + ts_shard_df = shard_df.filter(pl.col("timestamp").is_not_null()) + return summarize_dynamic_measurements(ts_columns, ts_shard_df) diff --git a/src/MEDS_tabular_automl/utils.py b/src/MEDS_tabular_automl/utils.py index 85ec597..bb68cad 100644 --- a/src/MEDS_tabular_automl/utils.py +++ b/src/MEDS_tabular_automl/utils.py @@ -225,7 +225,7 @@ def load_meds_data(MEDS_cohort_dir: str) -> Mapping[str, pl.DataFrame]: def setup_environment(cfg: DictConfig): # check output dir - flat_dir = Path(cfg.tabularized_data_dir) / "flat_reps" + flat_dir = Path(cfg.tabularized_data_dir) assert flat_dir.exists() # load MEDS data diff --git a/tests/test_tabularize.py b/tests/test_tabularize.py index 730ee0b..24c8f1c 100644 --- a/tests/test_tabularize.py +++ b/tests/test_tabularize.py @@ -104,11 +104,10 @@ def test_tabularize(): with tempfile.TemporaryDirectory() as d: MEDS_cohort_dir = Path(d) / "MEDS_cohort" - tabularized_data_dir = Path(d) / "cached_reps" + tabularized_data_dir = Path(d) / "flat_reps" # Create the directories MEDS_cohort_dir.mkdir() - tabularized_data_dir.mkdir() # Store MEDS outputs for split, data in MEDS_OUTPUTS.items(): @@ -140,5 +139,24 @@ def test_tabularize(): logger.info("caching flat representation of MEDS data") store_columns(cfg) tabularize_static_data(cfg) + actual_files = [ + (f.parent.stem, f.stem) for f in list(tabularized_data_dir.glob("static/*/*.parquet")) + ] + expected_files = [("train", "1"), ("train", "0"), ("held_out", "0"), ("tuning", 
"0")] + assert set(actual_files) == set(expected_files) tabularize_ts_data(cfg) + # confirm the time series files exist: + actual_files = [(f.parent.stem, f.stem) for f in list(tabularized_data_dir.glob("ts/*/*.parquet"))] + expected_files = [ + ("train", "1_value"), + ("train", "0_code"), + ("train", "0_value"), + ("train", "1_code"), + ("held_out", "0_code"), + ("held_out", "0_value"), + ("tuning", "0_code"), + ("tuning", "0_value"), + ] + assert set(actual_files) == set(expected_files) + summarize_ts_data_over_windows(cfg)