diff --git a/src/MEDS_tabular_automl/configs/launch_xgboost.yaml b/src/MEDS_tabular_automl/configs/launch_xgboost.yaml
index 93118e5..4fc1383 100644
--- a/src/MEDS_tabular_automl/configs/launch_xgboost.yaml
+++ b/src/MEDS_tabular_automl/configs/launch_xgboost.yaml
@@ -11,7 +11,7 @@ task_name: task
 # Task cached data dir
 input_dir: ${output_cohort_dir}/${task_name}/task_cache
 # Directory with task labels
-input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/
+input_label_dir: ${output_cohort_dir}/${task_name}/labels/
 # Where to output the model and cached data
 model_dir: ${output_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S}
 output_filepath: ${model_dir}/model_metadata.json
diff --git a/src/MEDS_tabular_automl/configs/tabularization.yaml b/src/MEDS_tabular_automl/configs/tabularization.yaml
index bcb0e78..cf03d63 100644
--- a/src/MEDS_tabular_automl/configs/tabularization.yaml
+++ b/src/MEDS_tabular_automl/configs/tabularization.yaml
@@ -6,7 +6,7 @@ defaults:
 # Raw data
 # Where the code metadata is stored
 input_code_metadata_fp: ${output_cohort_dir}/metadata/codes.parquet
-input_dir: ${output_cohort_data}/data
+input_dir: ${output_cohort_dir}/data
 output_dir: ${output_cohort_dir}/tabularize

 name: tabularization
diff --git a/src/MEDS_tabular_automl/configs/task_specific_caching.yaml b/src/MEDS_tabular_automl/configs/task_specific_caching.yaml
index ad8ed99..c002dfc 100644
--- a/src/MEDS_tabular_automl/configs/task_specific_caching.yaml
+++ b/src/MEDS_tabular_automl/configs/task_specific_caching.yaml
@@ -10,5 +10,8 @@ input_dir: ${output_cohort_dir}/tabularize
 input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels
 # Where to output the task specific tabularized data
 output_dir: ${output_cohort_dir}/${task_name}/task_cache
+output_label_dir: ${output_cohort_dir}/${task_name}/labels
+
+label_column: "boolean_value"

 name: task_specific_caching
diff --git a/src/MEDS_tabular_automl/scripts/cache_task.py b/src/MEDS_tabular_automl/scripts/cache_task.py
index 62df9b8..3a635ac 100644
--- a/src/MEDS_tabular_automl/scripts/cache_task.py
+++ b/src/MEDS_tabular_automl/scripts/cache_task.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 """Aggregates time-series data for feature columns across different window sizes."""

+from functools import partial
 from importlib.resources import files
 from pathlib import Path

@@ -10,6 +11,7 @@
 import scipy.sparse as sp
 from omegaconf import DictConfig

+from ..describe_codes import filter_parquet, get_feature_columns
 from ..file_name import list_subdir_files
 from ..mapper import wrap as rwlock_wrap
 from ..utils import (
@@ -17,7 +19,9 @@
     STATIC_CODE_AGGREGATION,
     STATIC_VALUE_AGGREGATION,
     VALUE_AGGREGATIONS,
+    get_events_df,
     get_shard_prefix,
+    get_unique_time_events_df,
     hydra_loguru_init,
     load_matrix,
     load_tqdm,
@@ -37,6 +41,10 @@
 ]


+def write_lazyframe(df: pl.LazyFrame, fp: Path):
+    df.collect().write_parquet(fp, use_pyarrow=True)
+
+
 def generate_row_cached_matrix(matrix: sp.coo_array, label_df: pl.LazyFrame) -> sp.coo_array:
     """Generates row-cached matrix for a given matrix and label DataFrame.
@@ -80,31 +88,44 @@ def main(cfg: DictConfig):
     tabularization_tasks = list_subdir_files(cfg.input_dir, "npz")
     np.random.shuffle(tabularization_tasks)

+    label_dir = Path(cfg.input_label_dir)
+    label_df = pl.scan_parquet(label_dir / "**/*.parquet").rename({"prediction_time": "time"})
+
+    feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp)
+
     # iterate through them
     for data_fp in iter_wrapper(tabularization_tasks):
         # parse as time series agg
         split, shard_num, window_size, code_type, agg_name = Path(data_fp).with_suffix("").parts[-5:]
-        label_fp = Path(cfg.input_label_dir) / split / f"{shard_num}.parquet"
-        out_fp = (Path(cfg.output_dir) / get_shard_prefix(cfg.input_dir, data_fp)).with_suffix(".npz")
-        assert label_fp.exists(), f"Output file {label_fp} does not exist."

-        def read_fn(fps):
-            matrix_fp, label_fp = fps
-            return load_matrix(matrix_fp), pl.scan_parquet(label_fp)
+        raw_data_fp = Path(cfg.output_cohort_dir) / "data" / split / f"{shard_num}.parquet"
+        raw_data_df = filter_parquet(raw_data_fp, cfg.tabularization._resolved_codes)
+        raw_data_df = (
+            get_unique_time_events_df(get_events_df(raw_data_df, feature_columns))
+            .with_row_index("event_id")
+            .select("patient_id", "time", "event_id")
+        )
+        shard_label_df = label_df.join_asof(other=raw_data_df, by="patient_id", on="time")

-        def compute_fn(shard_dfs):
-            matrix, label_df = shard_dfs
-            cache_matrix = generate_row_cached_matrix(matrix, label_df)
-            return cache_matrix
+        shard_label_fp = Path(cfg.output_label_dir) / split / f"{shard_num}.parquet"
+        rwlock_wrap(
+            raw_data_fp,
+            shard_label_fp,
+            pl.scan_parquet,
+            write_lazyframe,
+            lambda df: shard_label_df,
+            do_overwrite=cfg.do_overwrite,
+            do_return=False,
+        )

-        def write_fn(cache_matrix, out_fp):
-            write_df(cache_matrix, out_fp, do_overwrite=cfg.do_overwrite)
+        out_fp = (Path(cfg.output_dir) / get_shard_prefix(cfg.input_dir, data_fp)).with_suffix(".npz")
+        compute_fn = partial(generate_row_cached_matrix, label_df=shard_label_df)
+        write_fn = partial(write_df, do_overwrite=cfg.do_overwrite)

-        in_fps = [data_fp, label_fp]
         rwlock_wrap(
-            in_fps,
+            data_fp,
             out_fp,
-            read_fn,
+            load_matrix,
             write_fn,
             compute_fn,
             do_overwrite=cfg.do_overwrite,
diff --git a/tests/test_integration.py b/tests/test_integration.py
index b53f308..d22eac5 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -74,11 +74,13 @@ def test_integration():
         for split, data in MEDS_OUTPUTS.items():
             file_path = output_cohort_dir / "data" / f"{split}.parquet"
             file_path.parent.mkdir(exist_ok=True)
-            df = pl.read_csv(StringIO(data))
-            df.with_columns(pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet(file_path)
+            df = pl.read_csv(StringIO(data)).with_columns(
+                pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")
+            )
+            df.write_parquet(file_path)
             all_data.append(df)

-        all_data = pl.concat(all_data, how="diagnoal_relaxed")
+        all_data = pl.concat(all_data, how="diagonal_relaxed").sort(by=["patient_id", "time"])

         # Check the files are not empty
         meds_files = list_subdir_files(Path(cfg.input_dir), "parquet")
@@ -210,7 +212,7 @@
         overrides = [f"{k}={v}" for k, v in cache_config.items()]
         cfg = compose(config_name="task_specific_caching", overrides=overrides)  # config.yaml

-        df = get_unique_time_events_df(get_events_df(all_data, feature_columns)).collect()
+        df = get_unique_time_events_df(get_events_df(all_data.lazy(), feature_columns)).collect()
         pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]])
         df = df.with_columns(pl.Series(name="boolean_value", values=pseudo_labels))
         df = df.select("patient_id", pl.col("time").alias("prediction_time"), "boolean_value")
diff --git a/tests/test_tabularize.py b/tests/test_tabularize.py
index 0c8fd92..130721c 100644
--- a/tests/test_tabularize.py
+++ b/tests/test_tabularize.py
@@ -177,11 +177,13 @@ def test_tabularize():
         for split, data in MEDS_OUTPUTS.items():
             file_path = output_cohort_dir / "data" / f"{split}.parquet"
             file_path.parent.mkdir(exist_ok=True)
-            df = pl.read_csv(StringIO(data))
-            df.with_columns(pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet(file_path)
+            df = pl.read_csv(StringIO(data)).with_columns(
+                pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")
+            )
+            df.write_parquet(file_path)
             all_data.append(df)

-        all_data = pl.concat(all_data, how="diagnoal_relaxed")
+        all_data = pl.concat(all_data, how="diagonal_relaxed").sort(by=["patient_id", "time"])

         # Check the files are not empty
         meds_files = list_subdir_files(Path(cfg.input_dir), "parquet")
@@ -297,7 +299,7 @@ def test_tabularize():
         cfg = compose(config_name="task_specific_caching", overrides=overrides)  # config.yaml

         # Create fake labels
-        df = get_unique_time_events_df(get_events_df(all_data, feature_columns)).collect()
+        df = get_unique_time_events_df(get_events_df(all_data.lazy(), feature_columns)).collect()
         pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]])
         df = df.with_columns(pl.Series(name="boolean_value", values=pseudo_labels))
         df = df.select("patient_id", pl.col("time").alias("prediction_time"), "boolean_value")