diff --git a/scripts/identify_columns.py b/scripts/identify_columns.py index 186b6c5..b52eb41 100644 --- a/scripts/identify_columns.py +++ b/scripts/identify_columns.py @@ -5,16 +5,14 @@ from pathlib import Path import hydra +import numpy as np import polars as pl from loguru import logger from omegaconf import DictConfig, OmegaConf +from MEDS_tabular_automl.file_name import FileNameResolver from MEDS_tabular_automl.mapper import wrap as rwlock_wrap -from MEDS_tabular_automl.utils import ( - compute_feature_frequencies, - load_meds_data, - load_tqdm, -) +from MEDS_tabular_automl.utils import compute_feature_frequencies, load_tqdm def store_config_yaml(config_fp: Path, cfg: DictConfig): @@ -73,14 +71,12 @@ def store_columns( """ iter_wrapper = load_tqdm(cfg.tqdm) # create output dir - flat_dir = Path(cfg.tabularized_data_dir) + f_name_resolver = FileNameResolver(cfg) + flat_dir = f_name_resolver.tabularize_dir flat_dir.mkdir(exist_ok=True, parents=True) - # load MEDS data - split_to_fps = load_meds_data(cfg.MEDS_cohort_dir, load_data=False) - # store params in json file - config_fp = flat_dir / "config.yaml" + config_fp = f_name_resolver.get_config_path() store_config_yaml(config_fp, cfg) # 0. Identify Output Columns and Frequencies @@ -96,11 +92,11 @@ def read_fn(in_fp): return pl.scan_parquet(in_fp) # Map: Iterates through shards and caches feature frequencies - feature_freq_fp = flat_dir / "feature_freqs" - feature_freq_fp.mkdir(exist_ok=True) - for shard_fp in iter_wrapper(split_to_fps["train"]): - name = shard_fp.stem - out_fp = feature_freq_fp / f"{name}.json" + train_shards = f_name_resolver.list_meds_files(split="train") + np.random.shuffle(train_shards) + feature_dir = f_name_resolver.tabularize_dir + for shard_fp in iter_wrapper(train_shards): + out_fp = feature_dir / "identify_train_columns" / f"{shard_fp.stem}.json" rwlock_wrap( shard_fp, out_fp, @@ -123,16 +119,16 @@ def compute_fn(feature_freq_list): def write_fn(data, out_fp): feature_freqs, feature_columns = data - json.dump(feature_columns, open(out_fp / "feature_columns.json", "w")) - json.dump(feature_freqs, open(flat_dir / "feature_freqs.json", "w")) + json.dump(feature_columns, open(f_name_resolver.get_feature_columns_fp(), "w")) + json.dump(feature_freqs, open(f_name_resolver.get_feature_freqs_fp(), "w")) - def read_fn(in_fp): - files = list(in_fp.glob("*.json")) + def read_fn(feature_dir): + files = list(feature_dir.glob("*.json")) return [json.load(open(fp)) for fp in files] rwlock_wrap( - feature_freq_fp, - flat_dir, + feature_dir / "identify_train_columns", + feature_dir, read_fn, write_fn, compute_fn, diff --git a/scripts/summarize_over_windows.py b/scripts/summarize_over_windows.py index 66a4c71..fb9f4c5 100644 --- a/scripts/summarize_over_windows.py +++ b/scripts/summarize_over_windows.py @@ -1,26 +1,20 @@ #!/usr/bin/env python """Aggregates time-series data for feature columns across different window sizes.""" -import os +import json +from itertools import product import hydra +import numpy as np import polars as pl from loguru import logger from omegaconf import DictConfig +from MEDS_tabular_automl.file_name import FileNameResolver from MEDS_tabular_automl.generate_summarized_reps import generate_summary from MEDS_tabular_automl.generate_ts_features import get_flat_ts_rep from MEDS_tabular_automl.mapper import wrap as rwlock_wrap -from MEDS_tabular_automl.utils import setup_environment, write_df - - -def hydra_loguru_init() -> None: - """Adds loguru output to the logs that hydra scrapes. 
- - Must be called from a hydra main! - """ - hydra_path = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir - logger.add(os.path.join(hydra_path, "main.log")) +from MEDS_tabular_automl.utils import hydra_loguru_init, load_tqdm, write_df @hydra.main(version_base=None, config_path="../configs", config_name="tabularize") @@ -53,58 +47,60 @@ def summarize_ts_data_over_windows( FileNotFoundError: If specified directories or files in the configuration are not found. ValueError: If required columns like 'code' or 'value' are missing in the data files. """ + iter_wrapper = load_tqdm(cfg.tqdm) if not cfg.test: hydra_loguru_init() - flat_dir, split_to_fps, feature_columns = setup_environment(cfg, load_data=False) + f_name_resolver = FileNameResolver(cfg) # Produce ts representation - ts_subdir = flat_dir / "ts" - - for sp, shard_fps in split_to_fps.items(): - sp_dir = ts_subdir / sp - - for i, shard_fp in enumerate(shard_fps): - for window_size in cfg.window_sizes: - for agg in cfg.aggs: - pivot_fp = sp_dir / window_size / agg / f"{i}.pkl" - if pivot_fp.exists() and not cfg.do_overwrite: - raise FileExistsError( - f"do_overwrite is {cfg.do_overwrite} and {pivot_fp.exists()} exists!" - ) - - def read_fn(fp): - return pl.scan_parquet(fp) - - def compute_fn(shard_df): - # Load Sparse DataFrame - pivot_df = get_flat_ts_rep( - feature_columns=feature_columns, - shard_df=shard_df, - ) - - # Summarize data -- applying aggregations on various window sizes - summary_df = generate_summary( - feature_columns, - pivot_df, - window_size, - agg, - ) - assert summary_df.shape[1] > 2, "No data found in the summarized dataframe" - - logger.info("Writing pivot file") - return summary_df - - def write_fn(out_df, out_fp): - write_df(out_df, out_fp, do_overwrite=cfg.do_overwrite) - - rwlock_wrap( - shard_fp, - pivot_fp, - read_fn, - write_fn, - compute_fn, - do_overwrite=cfg.do_overwrite, - do_return=False, - ) + meds_shard_fps = f_name_resolver.list_meds_files() + feature_columns = json.load(open(f_name_resolver.get_feature_columns_fp())) + + # shuffle tasks + tabularization_tasks = list(product(meds_shard_fps, cfg.window_sizes, cfg.aggs)) + np.random.shuffle(tabularization_tasks) + + # iterate through them + for shard_fp, window_size, agg in iter_wrapper(tabularization_tasks): + shard_num = shard_fp.stem + split = shard_fp.parent.stem + assert split in ["train", "held_out", "tuning"], f"Invalid split {split}" + ts_fp = f_name_resolver.get_flat_ts_rep(split, shard_num, window_size, agg) + if ts_fp.exists() and not cfg.do_overwrite: + raise FileExistsError(f"do_overwrite is {cfg.do_overwrite} and {ts_fp.exists()} exists!") + + def read_fn(fp): + return pl.scan_parquet(fp) + + def compute_fn(shard_df): + # Load Sparse DataFrame + index_df, sparse_matrix = get_flat_ts_rep(feature_columns, shard_df) + + # Summarize data -- applying aggregations on a specific window size + aggregation combination + summary_df = generate_summary( + feature_columns, + index_df, + sparse_matrix, + window_size, + agg, + ) + assert summary_df.shape[1] > 2, "No data found in the summarized dataframe" + + logger.info("Writing pivot file") + return summary_df + + def write_fn(out_matrix, out_fp): + coo_matrix = out_matrix.tocoo() + write_df(coo_matrix, out_fp, do_overwrite=cfg.do_overwrite) + + rwlock_wrap( + shard_fp, + ts_fp, + read_fn, + write_fn, + compute_fn, + do_overwrite=cfg.do_overwrite, + do_return=False, + ) if __name__ == "__main__": diff --git a/scripts/tabularize_static.py b/scripts/tabularize_static.py index 
8f19ae6..d5ba698 100644 --- a/scripts/tabularize_static.py +++ b/scripts/tabularize_static.py @@ -1,15 +1,19 @@ #!/usr/bin/env python """Tabularizes static data in MEDS format into tabular representations.""" +import json +from itertools import product from pathlib import Path import hydra +import numpy as np import polars as pl from omegaconf import DictConfig, OmegaConf +from MEDS_tabular_automl.file_name import FileNameResolver from MEDS_tabular_automl.generate_static_features import get_flat_static_rep from MEDS_tabular_automl.mapper import wrap as rwlock_wrap -from MEDS_tabular_automl.utils import setup_environment, write_df +from MEDS_tabular_automl.utils import hydra_loguru_init, load_tqdm, write_df pl.enable_string_cache() @@ -96,44 +100,46 @@ def tabularize_static_data( .. _link: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.groupby_rolling.html # noqa: E501 """ - flat_dir, split_to_fp, feature_columns = setup_environment(cfg, load_data=False) - - # Produce static representation - static_subdir = flat_dir / "static" - - static_dfs = {} - for sp, shard_fps in split_to_fp.items(): - static_dfs[sp] = [] - sp_dir = static_subdir / sp - - for i, shard_fp in enumerate(shard_fps): - fp = sp_dir / f"{i}.parquet" - static_dfs[sp].append(fp) - if fp.exists() and not cfg.do_overwrite: - raise FileExistsError(f"do_overwrite is {cfg.do_overwrite} and {fp} exists!") - - def read_fn(in_fp): - return pl.scan_parquet(in_fp) - - def compute_fn(shard_df): - return get_flat_static_rep( - feature_columns=feature_columns, - shard_df=shard_df, - ) - - def write_fn(data, out_df): - write_df(data, out_df, do_overwrite=cfg.do_overwrite) - - rwlock_wrap( - shard_fp, - fp, - read_fn, - write_fn, - compute_fn, - do_overwrite=cfg.do_overwrite, - do_return=False, + iter_wrapper = load_tqdm(cfg.tqdm) + if not cfg.test: + hydra_loguru_init() + f_name_resolver = FileNameResolver(cfg) + # Produce ts representation + meds_shard_fps = f_name_resolver.list_meds_files() + # f_name_resolver.get_meds_dir() + feature_columns = json.load(open(f_name_resolver.get_feature_columns_fp())) + + # shuffle tasks + tabularization_tasks = list(product(meds_shard_fps, cfg.window_sizes, cfg.aggs)) + np.random.shuffle(tabularization_tasks) + + for shard_fp in iter_wrapper(meds_shard_fps): + static_fp = f_name_resolver.get_flat_static_rep(shard_fp.parent.stem, shard_fp.stem) + if static_fp.exists() and not cfg.do_overwrite: + raise FileExistsError(f"do_overwrite is {cfg.do_overwrite} and {static_fp} exists!") + + def read_fn(in_fp): + return pl.scan_parquet(in_fp) + + def compute_fn(shard_df): + return get_flat_static_rep( + feature_columns=feature_columns, + shard_df=shard_df, ) + def write_fn(data, out_df): + write_df(data, out_df, do_overwrite=cfg.do_overwrite) + + rwlock_wrap( + shard_fp, + static_fp, + read_fn, + write_fn, + compute_fn, + do_overwrite=cfg.do_overwrite, + do_return=False, + ) + if __name__ == "__main__": tabularize_static_data() diff --git a/scripts/tabularize_ts.py b/scripts/tabularize_ts.py deleted file mode 100644 index ae39595..0000000 --- a/scripts/tabularize_ts.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python -"""Tabularizes time-series data in MEDS format into tabular representations.""" - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig - -from MEDS_tabular_automl.generate_ts_features import get_flat_ts_rep -from MEDS_tabular_automl.mapper import wrap as rwlock_wrap -from MEDS_tabular_automl.utils import 
load_tqdm, setup_environment, write_df - - -@hydra.main(version_base=None, config_path="../configs", config_name="tabularize") -def tabularize_ts_data( - cfg: DictConfig, -): - """Processes a medical dataset to generates and stores flat representatiosn of time-series data. - - This function handles MEDS format data and pivots tables to create two types of data files - with patient_id and timestamp indexes: - code data: containing a column for every code and 1 and 0 values indicating presence - value data: containing a column for every code which the numerical value observed. - - Args: - cfg: configuration dictionary containing the necessary parameters for tabularizing the data. - """ - iter_wrapper = load_tqdm(cfg.tqdm) - flat_dir, split_to_fp, feature_columns = setup_environment(cfg, load_data=False) - - # Produce ts representation - ts_subdir = flat_dir / "ts" - - for sp, shard_fps in split_to_fp.items(): - sp_dir = ts_subdir / sp - - for i, shard_fp in enumerate(iter_wrapper(shard_fps)): - out_fp = sp_dir / f"{i}.pkl" - - def read_fn(in_fp): - return pl.scan_parquet(in_fp) - - def compute_fn(shard_df): - return get_flat_ts_rep( - feature_columns=feature_columns, - shard_df=shard_df, - ) - - def write_fn(data, out_df): - write_df(data, out_df, do_overwrite=cfg.do_overwrite) - - rwlock_wrap( - shard_fp, - out_fp, - read_fn, - write_fn, - compute_fn, - do_overwrite=cfg.do_overwrite, - do_return=False, - ) - logger.info("Generated TS flat representations.") - - -if __name__ == "__main__": - tabularize_ts_data() diff --git a/src/MEDS_tabular_automl/file_name.py b/src/MEDS_tabular_automl/file_name.py new file mode 100644 index 0000000..7c5b1cf --- /dev/null +++ b/src/MEDS_tabular_automl/file_name.py @@ -0,0 +1,72 @@ +"""Help functions for getting file names and paths for MEDS tabular automl tasks.""" +from pathlib import Path + +from omegaconf import DictConfig + + +class FileNameResolver: + def __init__(self, cfg: DictConfig): + self.cfg = cfg + self.meds_dir = Path(cfg.MEDS_cohort_dir) + self.tabularize_dir = Path(cfg.tabularized_data_dir) + + def get_meds_dir(self): + return self.meds_dir / "final_cohort" + + def get_static_dir(self): + return self.tabularize_dir / "static" + + def get_ts_dir(self): + return self.tabularize_dir / "ts" + + def get_sparse_dir(self): + return self.tabularize_dir / "sparse" + + def get_feature_columns_fp(self): + return self.tabularize_dir / "feature_columns.json" + + def get_feature_freqs_fp(self): + return self.tabularize_dir / "feature_freqs.json" + + def get_config_path(self): + return self.tabularize_dir / "config.yaml" + + def get_meds_shard(self, shard_num: int): + # Given a shard number, return the MEDS format data + return self.get_meds_dir() / f"{shard_num}.parquet" + + def get_flat_static_rep(self, split: str, shard_num: int): + # Given a shard number, returns the static representation path + return self.get_static_dir() / split / f"{shard_num}.parquet" + + def get_flat_ts_rep(self, split: str, shard_num: int, window_size: int, agg: str): + # Given a shard number, returns the time series representation path + return self.get_ts_dir() / split / f"{shard_num}" / f"{window_size}" / f"{agg}.npz" + + def get_flat_sparse_rep(self, split: str, shard_num: int, window_size: int, agg: str): + # Given a shard number, returns the sparse representation path + return self.get_sparse_dir() / split / f"{shard_num}" / f"{window_size}" / f"{agg}.npz" + + def list_meds_files(self, split=None): + # List all MEDS files + if split: + return 
sorted(list(self.get_meds_dir().glob(f"{split}/*.parquet"))) + return sorted(list(self.get_meds_dir().glob("*/*.parquet"))) + + def list_static_files(self, split=None): + # List all static files + if split: + return sorted(list(self.get_static_dir().glob(f"{split}/*.parquet"))) + return sorted(list(self.get_static_dir().glob("*/*.parquet"))) + + def list_ts_files(self, split=None): + # List all time series files + if split: + return sorted(list(self.get_ts_dir().glob(f"{split}/*/*/*/*.npz"))) + return sorted(list(self.get_ts_dir().glob("*/*/*/*/*.npz"))) + + def list_sparse_files(self, split=None): + # List all sparse files + if split: + return sorted(list(self.get_sparse_dir().glob(f"{split}/*/*.npz"))) + return sorted(list(self.get_sparse_dir().glob("*/*/*.npz"))) diff --git a/src/MEDS_tabular_automl/generate_summarized_reps.py b/src/MEDS_tabular_automl/generate_summarized_reps.py index 33812eb..c77cb14 100644 --- a/src/MEDS_tabular_automl/generate_summarized_reps.py +++ b/src/MEDS_tabular_automl/generate_summarized_reps.py @@ -127,8 +127,8 @@ def sparse_rolling(df, sparse_matrix, timedelta, agg): def get_rolling_window_indicies(index_df, window_size): """Get the indices for the rolling windows.""" if window_size == "full": - newest_date = df.select(pl.col("timestamp")).max().collect().item() - oldest_date = df.select(pl.col("timestamp")).min().collect().item() + newest_date = index_df.select(pl.col("timestamp")).max().collect().item() + oldest_date = index_df.select(pl.col("timestamp")).min().collect().item() timedelta = newest_date - oldest_date + pd.Timedelta(days=1) else: timedelta = pd.Timedelta(window_size) @@ -273,7 +273,7 @@ def _generate_summary( """ if agg not in VALID_AGGREGATIONS: raise ValueError(f"Invalid aggregation: {agg}. Valid options are: {VALID_AGGREGATIONS}") - out_matrix = compute_agg(index_df, sparse_matrix, window_size, agg, use_tqdm=use_tqdm) + out_matrix = compute_agg(index_df, matrix, window_size, agg, use_tqdm=use_tqdm) return out_matrix diff --git a/src/MEDS_tabular_automl/utils.py b/src/MEDS_tabular_automl/utils.py index 92027c4..2da18b0 100644 --- a/src/MEDS_tabular_automl/utils.py +++ b/src/MEDS_tabular_automl/utils.py @@ -6,22 +6,33 @@ DF_T: This defines the type of internal dataframes -- e.g. polars DataFrames. """ import json +import os from collections.abc import Mapping from pathlib import Path +import hydra import numpy as np -import pandas as pd import polars as pl import polars.selectors as cs import yaml from loguru import logger from omegaconf import DictConfig, OmegaConf +from scipy.sparse import coo_array DF_T = pl.LazyFrame WRITE_USE_PYARROW = True ROW_IDX_NAME = "__row_idx" +def hydra_loguru_init() -> None: + """Adds loguru output to the logs that hydra scrapes. + + Must be called from a hydra main! 
+ """ + hydra_path = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + logger.add(os.path.join(hydra_path, "main.log")) + + def load_tqdm(use_tqdm): if use_tqdm: from tqdm import tqdm @@ -42,7 +53,28 @@ def parse_static_feature_column(c: str) -> tuple[str, str, str, str]: return ("/".join(parts[:-2]), parts[-2], parts[-1]) -def write_df(df: DF_T, fp: Path, **kwargs): +def array_to_sparse_matrix(array: np.ndarray, shape: tuple[int, int]): + assert array.shape[0] == 3 + data, row, col = array + return coo_array((data, (row, col)), shape=shape) + + +def sparse_matrix_to_array(coo_matrix: coo_array): + return np.array([coo_matrix.data, coo_matrix.row, coo_matrix.col]), coo_matrix.shape + + +def store_matrix(coo_matrix: coo_array, fp_path: Path): + array, shape = sparse_matrix_to_array(coo_matrix) + np.savez(fp_path, array=array, shape=shape) + + +def load_matrix(fp_path: Path): + npzfile = np.load(fp_path) + array, shape = npzfile["array"], npzfile["shape"] + return array_to_sparse_matrix(array, shape) + + +def write_df(df: coo_array, fp: Path, **kwargs): """Write shard to disk.""" do_overwrite = kwargs.get("do_overwrite", False) @@ -55,16 +87,10 @@ def write_df(df: DF_T, fp: Path, **kwargs): df.collect().write_parquet(fp, use_pyarrow=WRITE_USE_PYARROW) elif isinstance(df, pl.DataFrame): df.write_parquet(fp, use_pyarrow=WRITE_USE_PYARROW) - elif isinstance(df, pd.DataFrame): - if not all(df.columns[:2] == ["patient_id", "timestamp"]): - raise ValueError( - f"Expected DataFrame to have columns ['patient_id', 'timestamp'], got {df.columns[:2]}" - ) - df.to_pickle(fp) - elif isinstance(df, np.matrix): - np.save(fp, df) + elif isinstance(df, coo_array): + store_matrix(df, fp) else: - raise ValueError(f"Unsupported type for df: {type(df)}") + raise TypeError(f"Unsupported type for df: {type(df)}") def get_static_col_dtype(col: str) -> pl.DataType: diff --git a/tests/test_tabularize.py b/tests/test_tabularize.py index ed6988c..5755c98 100644 --- a/tests/test_tabularize.py +++ b/tests/test_tabularize.py @@ -3,21 +3,19 @@ root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) import json -import shutil import tempfile from io import StringIO from pathlib import Path -import pandas as pd import polars as pl from hydra import compose, initialize from loguru import logger +from MEDS_tabular_automl.file_name import FileNameResolver +from MEDS_tabular_automl.utils import load_matrix from scripts.identify_columns import store_columns from scripts.summarize_over_windows import summarize_ts_data_over_windows -from scripts.tabularize_merge import merge_data from scripts.tabularize_static import tabularize_static_data -from scripts.tabularize_ts import tabularize_ts_data SPLITS_JSON = """{"train/0": [239684, 1195293], "train/1": [68729, 814703], "tuning/0": [754281], "held_out/0": [1500733]}""" # noqa: E501 @@ -104,80 +102,64 @@ } SUMMARIZE_EXPECTED_FILES = [ - "train/365d/value/sum/0.pkl", - "train/365d/value/sum/1.pkl", - "train/365d/code/count/0.pkl", - "train/365d/code/count/1.pkl", - "train/full/value/sum/0.pkl", - "train/full/value/sum/1.pkl", - "train/full/code/count/0.pkl", - "train/full/code/count/1.pkl", - "train/30d/value/sum/0.pkl", - "train/30d/value/sum/1.pkl", - "train/30d/code/count/0.pkl", - "train/30d/code/count/1.pkl", - "held_out/365d/value/sum/0.pkl", - "held_out/365d/code/count/0.pkl", - "held_out/full/value/sum/0.pkl", - "held_out/full/code/count/0.pkl", - "held_out/30d/value/sum/0.pkl", - "held_out/30d/code/count/0.pkl", - 
"tuning/365d/value/sum/0.pkl", - "tuning/365d/code/count/0.pkl", - "tuning/full/value/sum/0.pkl", - "tuning/full/code/count/0.pkl", - "tuning/30d/value/sum/0.pkl", - "tuning/30d/code/count/0.pkl", + "train/1/365d/value/sum.npz", + "train/1/365d/code/count.npz", + "train/1/full/value/sum.npz", + "train/1/full/code/count.npz", + "train/1/30d/value/sum.npz", + "train/1/30d/code/count.npz", + "train/0/365d/value/sum.npz", + "train/0/365d/code/count.npz", + "train/0/full/value/sum.npz", + "train/0/full/code/count.npz", + "train/0/30d/value/sum.npz", + "train/0/30d/code/count.npz", + "held_out/0/365d/value/sum.npz", + "held_out/0/365d/code/count.npz", + "held_out/0/full/value/sum.npz", + "held_out/0/full/code/count.npz", + "held_out/0/30d/value/sum.npz", + "held_out/0/30d/code/count.npz", + "tuning/0/365d/value/sum.npz", + "tuning/0/365d/code/count.npz", + "tuning/0/full/value/sum.npz", + "tuning/0/full/code/count.npz", + "tuning/0/30d/value/sum.npz", + "tuning/0/30d/code/count.npz", ] MERGE_EXPECTED_FILES = [ - "train/365d/value/sum/0.npy", - "train/365d/value/sum/1.npy", - "train/365d/code/count/0.npy", - "train/365d/code/count/1.npy", - "train/full/value/sum/0.npy", - "train/full/value/sum/1.npy", - "train/full/code/count/0.npy", - "train/full/code/count/1.npy", - "train/30d/value/sum/0.npy", - "train/30d/value/sum/1.npy", - "train/30d/code/count/0.npy", - "train/30d/code/count/1.npy", - "held_out/365d/value/sum/0.npy", - "held_out/365d/code/count/0.npy", - "held_out/full/value/sum/0.npy", - "held_out/full/code/count/0.npy", - "held_out/30d/value/sum/0.npy", - "held_out/30d/code/count/0.npy", - "tuning/365d/value/sum/0.npy", - "tuning/365d/code/count/0.npy", - "tuning/full/value/sum/0.npy", - "tuning/full/code/count/0.npy", - "tuning/30d/value/sum/0.npy", - "tuning/30d/code/count/0.npy", + "train/365d/value/sum/0.npz", + "train/365d/value/sum/1.npz", + "train/365d/code/count/0.npz", + "train/365d/code/count/1.npz", + "train/full/value/sum/0.npz", + "train/full/value/sum/1.npz", + "train/full/code/count/0.npz", + "train/full/code/count/1.npz", + "train/30d/value/sum/0.npz", + "train/30d/value/sum/1.npz", + "train/30d/code/count/0.npz", + "train/30d/code/count/1.npz", + "held_out/365d/value/sum/0.npz", + "held_out/365d/code/count/0.npz", + "held_out/full/value/sum/0.npz", + "held_out/full/code/count/0.npz", + "held_out/30d/value/sum/0.npz", + "held_out/30d/code/count/0.npz", + "tuning/365d/value/sum/0.npz", + "tuning/365d/code/count/0.npz", + "tuning/full/value/sum/0.npz", + "tuning/full/code/count/0.npz", + "tuning/30d/value/sum/0.npz", + "tuning/30d/code/count/0.npz", ] def test_tabularize(): with tempfile.TemporaryDirectory() as d: - MEDS_cohort_dir = Path(d) / "MEDS_cohort" - tabularized_data_dir = Path(d) / "flat_reps" - - # Create the directories - MEDS_cohort_dir.mkdir() - - # Store MEDS outputs - for split, data in MEDS_OUTPUTS.items(): - file_path = MEDS_cohort_dir / f"{split}.parquet" - file_path.parent.mkdir(exist_ok=True) - df = pl.read_csv(StringIO(data)) - df.with_columns(pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S.%f")).write_parquet( - file_path - ) - - split_json = json.load(StringIO(SPLITS_JSON)) - splits_fp = MEDS_cohort_dir / "splits.json" - json.dump(split_json, splits_fp.open("w")) + MEDS_cohort_dir = Path(d) / "processed" + tabularized_data_dir = Path(d) / "processed" / "tabularize" tabularize_config_kwargs = { "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), @@ -198,50 +180,61 @@ def test_tabularize(): with initialize(version_base=None, 
config_path="../configs/"): # path to config.yaml overrides = [f"{k}={v}" for k, v in tabularize_config_kwargs.items()] cfg = compose(config_name="tabularize", overrides=overrides) # config.yaml + + f_name_resolver = FileNameResolver(cfg) + + # Create the directories + (MEDS_cohort_dir / "final_cohort").mkdir(parents=True, exist_ok=True) + + # Store MEDS outputs + for split, data in MEDS_OUTPUTS.items(): + file_path = MEDS_cohort_dir / "final_cohort" / f"{split}.parquet" + file_path.parent.mkdir(exist_ok=True) + df = pl.read_csv(StringIO(data)) + df.with_columns(pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S.%f")).write_parquet( + file_path + ) + + # Check the files are not empty + meds_files = f_name_resolver.list_meds_files() + assert len(meds_files) == 4, "MEDS Data Files Should be 4!" + for f in meds_files: + assert pl.read_parquet(f).shape[0] > 0, "MEDS Data Tabular Dataframe Should not be Empty!" + + split_json = json.load(StringIO(SPLITS_JSON)) + splits_fp = MEDS_cohort_dir / "splits.json" + json.dump(split_json, splits_fp.open("w")) logger.info("caching flat representation of MEDS data") store_columns(cfg) assert (tabularized_data_dir / "config.yaml").is_file() assert (tabularized_data_dir / "feature_columns.json").is_file() assert (tabularized_data_dir / "feature_freqs.json").is_file() tabularize_static_data(cfg) - actual_files = [ - (f.parent.stem, f.stem) for f in list(tabularized_data_dir.glob("static/*/*.parquet")) - ] + actual_files = [(f.parent.stem, f.stem) for f in f_name_resolver.list_static_files()] expected_files = [("train", "1"), ("train", "0"), ("held_out", "0"), ("tuning", "0")] + f_name_resolver.get_static_dir() assert set(actual_files) == set(expected_files) # Check the files are not empty for f in list(tabularized_data_dir.glob("static/*/*.parquet")): assert pl.read_parquet(f).shape[0] > 0, "Static Data Tabular Dataframe Should not be Empty!" - tabularize_ts_data(cfg) - # confirm the time series files exist: - actual_files = [(f.parent.stem, f.stem) for f in list(tabularized_data_dir.glob("ts/*/*.pkl"))] - expected_files = [ - ("train", "1"), - ("train", "0"), - ("held_out", "0"), - ("tuning", "0"), - ] - assert set(actual_files) == set(expected_files) - for f in list(tabularized_data_dir.glob("ts/*/*.pkl")): - assert pd.read_pickle(f).shape[0] > 0, "Time-Series Tabular Dataframe Should not be Empty!" - shutil.rmtree(tabularized_data_dir / "ts") - summarize_ts_data_over_windows(cfg) # confirm summary files exist: - output_files = list(tabularized_data_dir.glob("ts/*/*/*/*/*.pkl")) + output_files = list(tabularized_data_dir.glob("ts/*/*/*/*/*.npz")) + f_name_resolver.list_ts_files() actual_files = [str(Path(*f.parts[-5:])) for f in output_files] assert set(actual_files) == set(SUMMARIZE_EXPECTED_FILES) for f in output_files: - df = pd.read_pickle(f) - assert df.shape[0] > 0 - - merge_data(cfg) - output_files = list(tabularized_data_dir.glob("sparse/*/*/*/*/*.npy")) - actual_files = [str(Path(*f.parts[-5:])) for f in output_files] - assert set(actual_files) == set(MERGE_EXPECTED_FILES) + sparse_array = load_matrix(f) + assert sparse_array.shape[0] > 0 + assert sparse_array.shape[1] > 0 + + # merge_data(cfg) + # output_files = list(tabularized_data_dir.glob("sparse/*/*/*/*/*.npz")) + # actual_files = [str(Path(*f.parts[-5:])) for f in output_files] + # assert set(actual_files) == set(MERGE_EXPECTED_FILES) # model_dir = Path(d) / "save_model" # xgboost_config_kwargs = {
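For reference, a minimal usage sketch of the serialization helpers and path resolver this patch introduces (store_matrix / load_matrix in MEDS_tabular_automl/utils.py and FileNameResolver in MEDS_tabular_automl/file_name.py). The config values, directory paths, and toy matrix below are placeholders for illustration, not values taken from the repository's configs; in the scripts themselves the write path goes through write_df, which dispatches coo_array inputs to store_matrix.

import numpy as np
from omegaconf import OmegaConf
from scipy.sparse import coo_array

from MEDS_tabular_automl.file_name import FileNameResolver
from MEDS_tabular_automl.utils import load_matrix, store_matrix

# Hypothetical config carrying only the two keys FileNameResolver reads.
cfg = OmegaConf.create(
    {"MEDS_cohort_dir": "/tmp/processed", "tabularized_data_dir": "/tmp/processed/tabularize"}
)
f_name_resolver = FileNameResolver(cfg)

# Resolve the output path for one (split, shard, window, agg) combination,
# e.g. .../tabularize/ts/train/0/30d/value/sum.npz
ts_fp = f_name_resolver.get_flat_ts_rep("train", "0", "30d", "value/sum")
ts_fp.parent.mkdir(parents=True, exist_ok=True)

# A toy 3x4 COO matrix standing in for one summarized shard.
matrix = coo_array(
    (np.array([1.0, 2.0, 3.0]), (np.array([0, 1, 2]), np.array([0, 2, 3]))),
    shape=(3, 4),
)

# store_matrix stacks (data, row, col) into a single array and saves it with the
# shape via np.savez; load_matrix rebuilds the coo_array from that .npz file.
store_matrix(matrix, ts_fp)
restored = load_matrix(ts_fp)
assert (matrix.toarray() == restored.toarray()).all()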