diff --git a/README.md b/README.md index 3ce789f..838081a 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,8 @@ This repository consists of two key pieces: ## Quick Start -To use MEDS-Tab, install the dependencies following commands below: +To use MEDS-Tab, install the dependencies following commands below. Note that this version of MEDS-Tab is +compatible with [MEDS v0.3](https://github.com/Medical-Event-Data-Standard/meds/releases/tag/0.3.0) **Pip Install** @@ -44,10 +45,10 @@ pip install . ## Scripts and Examples -For an end-to-end example over MIMIC-IV, see the [MIMIC-IV companion repository](https://github.com/mmcdermott/MEDS_TAB_MIMIC_IV). -For an end-to-end example over Philips eICU, see the [eICU companion repository](https://github.com/mmcdermott/MEDS_TAB_EICU). +For an end to end example, including re-sharding the input via MEDS-Transforms, see +[this example script](https://gist.github.com/mmcdermott/34194e484d7b2a2f68967b9bbccfb35b) -See [`/tests/test_integration.py`](https://github.com/mmcdermott/MEDS_Tabular_AutoML/blob/main/tests/test_integration.py) for a local example of the end-to-end pipeline being run on synthetic data. This script is a functional test that is also run with `pytest` to verify the correctness of the algorithm. +See [`/tests/test_integration.py`](https://github.com/mmcdermott/MEDS_Tabular_AutoML/blob/main/tests/test_integration.py) for a local example of the end-to-end pipeline (minus re-sharding) being run on synthetic data. This script is a functional test that is also run with `pytest` to verify the correctness of the algorithm. ## Why MEDS-Tab? @@ -73,6 +74,28 @@ By following these steps, you can seamlessly transform your dataset, define nece ## Core CLI Scripts Overview +0. First, if your data is not already sharded to the degree you want and in a manner that subdivides your + splits with the format `"$SPLIT_NAME/\d+.parquet"`, where `$SPLIT_NAME` does not contain slashes, you will + need to re-shard your data. This can be done via the + [MEDS-Transforms](https://github.com/mmcdermott/MEDS_transforms) library, which is not included in this + repository. Having data sharded by split _is a necessary step_ to ensure that the data is efficiently + processed in parallel. You can easily re-shard your input MEDS cohort in the environment into which this + package is installed with the following command: + + ```console + # Re-shard pipeline + # $MIMICIV_MEDS_DIR is the directory containing the input, MEDS v0.3 formatted MIMIC-IV data + # $MEDS_TAB_COHORT_DIR is the directory where the re-sharded MEDS dataset will be stored, and where your model + # will store cached files during processing by default. + # $N_PATIENTS_PER_SHARD is the number of patients per shard you want to use. + MEDS_transform-reshard_to_split \ + input_dir="$MIMICIV_MEDS_DIR" \ + cohort_dir="$MEDS_TAB_COHORT_DIR" \ + 'stages=["reshard_to_split"]' \ + stage="reshard_to_split" \ + stage_configs.reshard_to_split.n_patients_per_shard=$N_PATIENTS_PER_SHARD + ``` + 1. **`meds-tab-describe`**: This command processes MEDS data shards to compute the frequencies of different code types. It differentiates codes into the following categories: - time-series codes (codes with timestamps) diff --git a/pyproject.toml b/pyproject.toml index 2fe30be..2bc4079 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,11 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dependencies = ["polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost", "scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins"] +dependencies = [ + "polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost", + "scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins", "meds==0.3", + "MEDS-transforms==0.0.4", +] [project.scripts] meds-tab-describe = "MEDS_tabular_automl.scripts.describe_codes:main" diff --git a/src/MEDS_tabular_automl/configs/default.yaml b/src/MEDS_tabular_automl/configs/default.yaml index 8f8513c..82a2164 100644 --- a/src/MEDS_tabular_automl/configs/default.yaml +++ b/src/MEDS_tabular_automl/configs/default.yaml @@ -1,11 +1,13 @@ MEDS_cohort_dir: ??? +output_cohort_dir: ??? do_overwrite: False seed: 1 tqdm: False worker: 0 loguru_init: False -log_dir: ${output_dir}/.logs/ +log_dir: ${output_cohort_dir}/.logs/ +cache_dir: ${output_cohort_dir}/.cache hydra: verbose: False diff --git a/src/MEDS_tabular_automl/configs/describe_codes.yaml b/src/MEDS_tabular_automl/configs/describe_codes.yaml index d171513..ec980bf 100644 --- a/src/MEDS_tabular_automl/configs/describe_codes.yaml +++ b/src/MEDS_tabular_automl/configs/describe_codes.yaml @@ -2,13 +2,8 @@ defaults: - default - _self_ -# split we wish to get metadata for -split: train -# Raw data, must have a subdirectory "train" with the training data split -input_dir: ${MEDS_cohort_dir}/final_cohort/${split} +input_dir: ${output_cohort_dir}/data # Where to store output code frequency data -cache_dir: ${MEDS_cohort_dir}/.cache -output_dir: ${MEDS_cohort_dir} -output_filepath: ${output_dir}/code_metadata.parquet +output_filepath: ${output_cohort_dir}/metadata/codes.parquet name: describe_codes diff --git a/src/MEDS_tabular_automl/configs/launch_xgboost.yaml b/src/MEDS_tabular_automl/configs/launch_xgboost.yaml index 3dce8bc..4fc1383 100644 --- a/src/MEDS_tabular_automl/configs/launch_xgboost.yaml +++ b/src/MEDS_tabular_automl/configs/launch_xgboost.yaml @@ -9,13 +9,12 @@ defaults: task_name: task # Task cached data dir -input_dir: ${MEDS_cohort_dir}/${task_name}/task_cache +input_dir: ${output_cohort_dir}/${task_name}/task_cache # Directory with task labels -input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/final_cohort +input_label_dir: ${output_cohort_dir}/${task_name}/labels/ # Where to output the model and cached data -output_dir: ${MEDS_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S} -output_filepath: ${output_dir}/model_metadata.parquet -cache_dir: ${MEDS_cohort_dir}/.cache +model_dir: ${output_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S} +output_filepath: ${model_dir}/model_metadata.json # Model parameters model_params: @@ -31,13 +30,9 @@ model_params: keep_data_in_memory: True binarize_task: True -hydra: - verbose: False - sweep: - dir: ${output_dir}/.logs/ - run: - dir: ${output_dir}/.logs/ +log_dir: ${model_dir}/.logs/ +hydra: # Optuna Sweeper sweeper: sampler: diff --git a/src/MEDS_tabular_automl/configs/tabularization.yaml b/src/MEDS_tabular_automl/configs/tabularization.yaml index dd40e3f..cf03d63 100644 --- a/src/MEDS_tabular_automl/configs/tabularization.yaml +++ b/src/MEDS_tabular_automl/configs/tabularization.yaml @@ -5,8 +5,8 @@ defaults: # Raw data # Where the code metadata is stored -input_code_metadata_fp: ${MEDS_cohort_dir}/code_metadata.parquet -input_dir: ${MEDS_cohort_dir}/final_cohort -output_dir: ${MEDS_cohort_dir}/tabularize +input_code_metadata_fp: ${output_cohort_dir}/metadata/codes.parquet +input_dir: ${output_cohort_dir}/data +output_dir: ${output_cohort_dir}/tabularize name: tabularization diff --git a/src/MEDS_tabular_automl/configs/tabularization/default.yaml b/src/MEDS_tabular_automl/configs/tabularization/default.yaml index d11dd62..3f8761c 100644 --- a/src/MEDS_tabular_automl/configs/tabularization/default.yaml +++ b/src/MEDS_tabular_automl/configs/tabularization/default.yaml @@ -1,7 +1,7 @@ # User inputs allowed_codes: null min_code_inclusion_frequency: 10 -filtered_code_metadata_fp: ${MEDS_cohort_dir}/tabularized_code_metadata.parquet +filtered_code_metadata_fp: ${output_cohort_dir}/tabularized_code_metadata.parquet window_sizes: - "1d" - "7d" diff --git a/src/MEDS_tabular_automl/configs/task_specific_caching.yaml b/src/MEDS_tabular_automl/configs/task_specific_caching.yaml index eb1c98e..c002dfc 100644 --- a/src/MEDS_tabular_automl/configs/task_specific_caching.yaml +++ b/src/MEDS_tabular_automl/configs/task_specific_caching.yaml @@ -5,10 +5,13 @@ defaults: task_name: task # Tabularized Data -input_dir: ${MEDS_cohort_dir}/tabularize +input_dir: ${output_cohort_dir}/tabularize # Where the labels are stored, with columns patient_id, timestamp, label -input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels/final_cohort +input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels # Where to output the task specific tabularized data -output_dir: ${MEDS_cohort_dir}/${task_name}/task_cache +output_dir: ${output_cohort_dir}/${task_name}/task_cache +output_label_dir: ${output_cohort_dir}/${task_name}/labels + +label_column: "boolean_value" name: task_specific_caching diff --git a/src/MEDS_tabular_automl/describe_codes.py b/src/MEDS_tabular_automl/describe_codes.py index 2ded0e8..5a86d0e 100644 --- a/src/MEDS_tabular_automl/describe_codes.py +++ b/src/MEDS_tabular_automl/describe_codes.py @@ -82,7 +82,7 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame: >>> data = pl.DataFrame({ ... 'patient_id': [1, 1, 2, 2, 3, 3, 3], ... 'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'], - ... 'timestamp': [ + ... 'time': [ ... None, ... datetime(2021, 1, 1), ... None, @@ -91,7 +91,7 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame: ... datetime(2021, 1, 4), ... None ... ], - ... 'numerical_value': [1, None, 2, 2, None, None, 3] + ... 'numeric_value': [1, None, 2, 2, None, None, 3] ... }).lazy() >>> assert ( ... convert_to_freq_dict(compute_feature_frequencies(data).lazy()) == { @@ -101,29 +101,29 @@ def compute_feature_frequencies(shard_df: DF_T) -> pl.DataFrame: ... ) """ static_df = shard_df.filter( - pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_null() + pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("time").is_null() ) static_code_freqs_df = static_df.group_by("code").agg(pl.count("code").alias("count")).collect() static_code_freqs = { row["code"] + "/static/present": row["count"] for row in static_code_freqs_df.iter_rows(named=True) } - static_value_df = static_df.filter(pl.col("numerical_value").is_not_null()) + static_value_df = static_df.filter(pl.col("numeric_value").is_not_null()) static_value_freqs_df = ( - static_value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect() + static_value_df.group_by("code").agg(pl.count("numeric_value").alias("count")).collect() ) static_value_freqs = { row["code"] + "/static/first": row["count"] for row in static_value_freqs_df.iter_rows(named=True) } ts_df = shard_df.filter( - pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_not_null() + pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("time").is_not_null() ) code_freqs_df = ts_df.group_by("code").agg(pl.count("code").alias("count")).collect() code_freqs = {row["code"] + "/code": row["count"] for row in code_freqs_df.iter_rows(named=True)} - value_df = ts_df.filter(pl.col("numerical_value").is_not_null()) - value_freqs_df = value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect() + value_df = ts_df.filter(pl.col("numeric_value").is_not_null()) + value_freqs_df = value_df.group_by("code").agg(pl.count("numeric_value").alias("count")).collect() value_freqs = {row["code"] + "/value": row["count"] for row in value_freqs_df.iter_rows(named=True)} combined_freqs = {**static_code_freqs, **static_value_freqs, **code_freqs, **value_freqs} @@ -222,23 +222,23 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame: >>> fp = NamedTemporaryFile() >>> pl.DataFrame({ ... "code": ["A", "A", "A", "A", "D", "D", "E", "E"], - ... "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"], - ... "numerical_value": [1, None, 2, 2, None, 5, None, 3] + ... "time": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"], + ... "numeric_value": [1, None, 2, 2, None, 5, None, 3] ... }).write_parquet(fp.name) >>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect() shape: (6, 3) - ┌──────┬────────────┬─────────────────┐ - │ code ┆ timestamp ┆ numerical_value │ - │ --- ┆ --- ┆ --- │ - │ str ┆ str ┆ i64 │ - ╞══════╪════════════╪═════════════════╡ - │ A ┆ 2021-01-01 ┆ null │ - │ A ┆ 2021-01-01 ┆ null │ - │ D ┆ null ┆ null │ - │ D ┆ null ┆ null │ - │ E ┆ 2021-01-03 ┆ null │ - │ E ┆ 2021-01-04 ┆ 3 │ - └──────┴────────────┴─────────────────┘ + ┌──────┬────────────┬───────────────┐ + │ code ┆ time ┆ numeric_value │ + │ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ i64 │ + ╞══════╪════════════╪═══════════════╡ + │ A ┆ 2021-01-01 ┆ null │ + │ A ┆ 2021-01-01 ┆ null │ + │ D ┆ null ┆ null │ + │ D ┆ null ┆ null │ + │ E ┆ 2021-01-03 ┆ null │ + │ E ┆ 2021-01-04 ┆ 3 │ + └──────┴────────────┴───────────────┘ >>> fp.close() """ df = pl.scan_parquet(fp) @@ -257,8 +257,8 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame: clear_code_aggregation_suffix(each) for each in get_feature_names("value/sum", allowed_codes) ] - is_static_code = pl.col("timestamp").is_null() - is_numeric_code = pl.col("numerical_value").is_not_null() + is_static_code = pl.col("time").is_null() + is_numeric_code = pl.col("numeric_value").is_not_null() rare_static_code = is_static_code & ~pl.col("code").is_in(static_present_feature_columns) rare_ts_code = ~is_static_code & ~pl.col("code").is_in(code_feature_columns) rare_ts_value = ~is_static_code & ~pl.col("code").is_in(value_feature_columns) & is_numeric_code @@ -268,8 +268,8 @@ def filter_parquet(fp: Path, allowed_codes: list[str]) -> pl.LazyFrame: df = df.with_columns( pl.when(rare_static_value | rare_ts_value) .then(None) - .otherwise(pl.col("numerical_value")) - .alias("numerical_value") + .otherwise(pl.col("numeric_value")) + .alias("numeric_value") ) # Drop rows with rare codes df = df.filter(~(rare_static_code | rare_ts_code)) diff --git a/src/MEDS_tabular_automl/generate_static_features.py b/src/MEDS_tabular_automl/generate_static_features.py index 5ad4b30..8ff4003 100644 --- a/src/MEDS_tabular_automl/generate_static_features.py +++ b/src/MEDS_tabular_automl/generate_static_features.py @@ -119,7 +119,7 @@ def summarize_static_measurements( code_subset = df.filter(pl.col("code").is_in(static_first_codes)) first_code_subset = code_subset.group_by(pl.col("patient_id")).first().collect() static_value_pivot_df = first_code_subset.pivot( - index=["patient_id"], columns=["code"], values=["numerical_value"], aggregate_function=None + index=["patient_id"], columns=["code"], values=["numeric_value"], aggregate_function=None ) # rename code to feature name remap_cols = { diff --git a/src/MEDS_tabular_automl/generate_summarized_reps.py b/src/MEDS_tabular_automl/generate_summarized_reps.py index f9afb70..c2dcb86 100644 --- a/src/MEDS_tabular_automl/generate_summarized_reps.py +++ b/src/MEDS_tabular_automl/generate_summarized_reps.py @@ -59,7 +59,7 @@ def get_rolling_window_indicies(index_df: pl.LazyFrame, window_size: str) -> pl. timedelta = pd.Timedelta(window_size) return ( index_df.with_row_index("index") - .rolling(index_column="timestamp", period=timedelta, group_by="patient_id") + .rolling(index_column="time", period=timedelta, group_by="patient_id") .agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")]) .select(pl.col("min_index", "max_index")) .collect() @@ -133,11 +133,11 @@ def compute_agg( """Applies aggregation to a sparse matrix using rolling window indices derived from a DataFrame. Dataframe is expected to only have the relevant columns for aggregating. It should have the patient_id and - timestamp columns, and then only code columns if agg is a code aggregation or only value columns if it is + time columns, and then only code columns if agg is a code aggregation or only value columns if it is a value aggreagation. Args: - index_df: The DataFrame with 'patient_id' and 'timestamp' columns used for grouping. + index_df: The DataFrame with 'patient_id' and 'time' columns used for grouping. matrix: The sparse matrix to be aggregated. window_size: The string defining the rolling window size. agg: The string specifying the aggregation method. @@ -149,11 +149,11 @@ def compute_agg( """ group_df = ( index_df.with_row_index("index") - .group_by(["patient_id", "timestamp"], maintain_order=True) + .group_by(["patient_id", "time"], maintain_order=True) .agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")]) .collect() ) - index_df = group_df.lazy().select(pl.col("patient_id", "timestamp")) + index_df = group_df.lazy().select(pl.col("patient_id", "time")) windows = group_df.select(pl.col("min_index", "max_index")) logger.info("Step 1.5: Running sparse aggregation.") matrix = aggregate_matrix(windows, matrix, agg, num_features, use_tqdm) diff --git a/src/MEDS_tabular_automl/generate_ts_features.py b/src/MEDS_tabular_automl/generate_ts_features.py index 65f95ab..331f65e 100644 --- a/src/MEDS_tabular_automl/generate_ts_features.py +++ b/src/MEDS_tabular_automl/generate_ts_features.py @@ -57,7 +57,7 @@ def get_long_code_df( .to_series() .to_numpy() ) - assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type" + assert np.issubdtype(cols.dtype, np.number), "numeric_value must be a numerical type" data = np.ones(df.select(pl.len()).collect().item(), dtype=np.bool_) return data, (rows, cols) @@ -76,9 +76,7 @@ def get_long_value_df( the CSR sparse matrix. """ column_to_int = {feature_name_to_code(col): i for i, col in enumerate(ts_columns)} - value_df = ( - df.with_row_index("index").drop_nulls("numerical_value").filter(pl.col("code").is_in(ts_columns)) - ) + value_df = df.with_row_index("index").drop_nulls("numeric_value").filter(pl.col("code").is_in(ts_columns)) rows = value_df.select(pl.col("index")).collect().to_series().to_numpy() cols = ( value_df.with_columns(pl.col("code").cast(str).replace(column_to_int).cast(int).alias("value_index")) @@ -87,8 +85,8 @@ def get_long_value_df( .to_series() .to_numpy() ) - assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type" - data = value_df.select(pl.col("numerical_value")).collect().to_series().to_numpy() + assert np.issubdtype(cols.dtype, np.number), "numeric_value must be a numerical type" + data = value_df.select(pl.col("numeric_value")).collect().to_series().to_numpy() return data, (rows, cols) @@ -109,7 +107,7 @@ def summarize_dynamic_measurements( of aggregated values. """ logger.info("Generating Sparse matrix for Time Series Features") - id_cols = ["patient_id", "timestamp"] + id_cols = ["patient_id", "time"] # Confirm dataframe is sorted check_df = df.select(pl.col(id_cols)) @@ -117,7 +115,7 @@ def summarize_dynamic_measurements( # Generate sparse matrix if agg in CODE_AGGREGATIONS: - code_df = df.drop(*(id_cols + ["numerical_value"])) + code_df = df.drop(*(id_cols + ["numeric_value"])) data, (rows, cols) = get_long_code_df(code_df, ts_columns) elif agg in VALUE_AGGREGATIONS: value_df = df.drop(*id_cols) diff --git a/src/MEDS_tabular_automl/scripts/cache_task.py b/src/MEDS_tabular_automl/scripts/cache_task.py index 62df9b8..15c194b 100644 --- a/src/MEDS_tabular_automl/scripts/cache_task.py +++ b/src/MEDS_tabular_automl/scripts/cache_task.py @@ -1,6 +1,7 @@ #!/usr/bin/env python """Aggregates time-series data for feature columns across different window sizes.""" +from functools import partial from importlib.resources import files from pathlib import Path @@ -10,6 +11,7 @@ import scipy.sparse as sp from omegaconf import DictConfig +from ..describe_codes import filter_parquet, get_feature_columns from ..file_name import list_subdir_files from ..mapper import wrap as rwlock_wrap from ..utils import ( @@ -17,7 +19,9 @@ STATIC_CODE_AGGREGATION, STATIC_VALUE_AGGREGATION, VALUE_AGGREGATIONS, + get_events_df, get_shard_prefix, + get_unique_time_events_df, hydra_loguru_init, load_matrix, load_tqdm, @@ -37,6 +41,10 @@ ] +def write_lazyframe(df: pl.LazyFrame, fp: Path): + df.collect().write_parquet(fp, use_pyarrow=True) + + def generate_row_cached_matrix(matrix: sp.coo_array, label_df: pl.LazyFrame) -> sp.coo_array: """Generates row-cached matrix for a given matrix and label DataFrame. @@ -80,31 +88,51 @@ def main(cfg: DictConfig): tabularization_tasks = list_subdir_files(cfg.input_dir, "npz") np.random.shuffle(tabularization_tasks) + label_dir = Path(cfg.input_label_dir) + label_df = pl.scan_parquet(label_dir / "**/*.parquet").rename( + { + "prediction_time": "time", + cfg.label_column: "label", + } + ) + + feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp) + # iterate through them for data_fp in iter_wrapper(tabularization_tasks): # parse as time series agg split, shard_num, window_size, code_type, agg_name = Path(data_fp).with_suffix("").parts[-5:] - label_fp = Path(cfg.input_label_dir) / split / f"{shard_num}.parquet" - out_fp = (Path(cfg.output_dir) / get_shard_prefix(cfg.input_dir, data_fp)).with_suffix(".npz") - assert label_fp.exists(), f"Output file {label_fp} does not exist." - def read_fn(fps): - matrix_fp, label_fp = fps - return load_matrix(matrix_fp), pl.scan_parquet(label_fp) + raw_data_fp = Path(cfg.output_cohort_dir) / "data" / split / f"{shard_num}.parquet" + raw_data_df = filter_parquet(raw_data_fp, cfg.tabularization._resolved_codes) + raw_data_df = ( + get_unique_time_events_df(get_events_df(raw_data_df, feature_columns)) + .with_row_index("event_id") + .select("patient_id", "time", "event_id") + ) + shard_label_df = label_df.join( + raw_data_df.select("patient_id").unique(), on="patient_id", how="inner" + ).join_asof(other=raw_data_df, by="patient_id", on="time") - def compute_fn(shard_dfs): - matrix, label_df = shard_dfs - cache_matrix = generate_row_cached_matrix(matrix, label_df) - return cache_matrix + shard_label_fp = Path(cfg.output_label_dir) / split / f"{shard_num}.parquet" + rwlock_wrap( + raw_data_fp, + shard_label_fp, + pl.scan_parquet, + write_lazyframe, + lambda df: shard_label_df, + do_overwrite=cfg.do_overwrite, + do_return=False, + ) - def write_fn(cache_matrix, out_fp): - write_df(cache_matrix, out_fp, do_overwrite=cfg.do_overwrite) + out_fp = (Path(cfg.output_dir) / get_shard_prefix(cfg.input_dir, data_fp)).with_suffix(".npz") + compute_fn = partial(generate_row_cached_matrix, label_df=shard_label_df) + write_fn = partial(write_df, do_overwrite=cfg.do_overwrite) - in_fps = [data_fp, label_fp] rwlock_wrap( - in_fps, + data_fp, out_fp, - read_fn, + load_matrix, write_fn, compute_fn, do_overwrite=cfg.do_overwrite, diff --git a/src/MEDS_tabular_automl/scripts/describe_codes.py b/src/MEDS_tabular_automl/scripts/describe_codes.py index fdee111..c29b542 100644 --- a/src/MEDS_tabular_automl/scripts/describe_codes.py +++ b/src/MEDS_tabular_automl/scripts/describe_codes.py @@ -8,7 +8,7 @@ import numpy as np import polars as pl from loguru import logger -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig from ..describe_codes import ( compute_feature_frequencies, @@ -36,15 +36,6 @@ def main(cfg: DictConfig): if not cfg.loguru_init: hydra_loguru_init() - # Store Config - output_dir = Path(cfg.output_dir) - output_dir.mkdir(exist_ok=True, parents=True) - OmegaConf.save(cfg, output_dir / "config.yaml") - - # Create output dir - input_dir = Path(cfg.input_dir) - input_dir.mkdir(exist_ok=True, parents=True) - # 0. Identify Output Columns and Frequencies logger.info("Iterating through shards and caching feature frequencies.") diff --git a/src/MEDS_tabular_automl/scripts/launch_xgboost.py b/src/MEDS_tabular_automl/scripts/launch_xgboost.py index 2f52dc7..25bd7de 100644 --- a/src/MEDS_tabular_automl/scripts/launch_xgboost.py +++ b/src/MEDS_tabular_automl/scripts/launch_xgboost.py @@ -1,5 +1,4 @@ from collections.abc import Callable, Mapping -from datetime import datetime from importlib.resources import files from pathlib import Path @@ -440,10 +439,10 @@ def main(cfg: DictConfig) -> float: # print("Held Out Iterator Time: \n", model.iheld_out._profile_durations()) # save model - save_dir = Path(cfg.output_dir) - save_dir.mkdir(parents=True, exist_ok=True) - model_time = datetime.now().strftime("%H%M%S%f") - model.model.save_model(save_dir / f"{auc:.4f}_model_{model_time}.json") + output_fp = Path(cfg.output_filepath) + output_fp.parent.mkdir(parents=True, exist_ok=True) + + model.model.save_model(output_fp) except Exception as e: logger.error(f"Error occurred: {e}") auc = 0.0 diff --git a/src/MEDS_tabular_automl/utils.py b/src/MEDS_tabular_automl/utils.py index c398a39..49de128 100644 --- a/src/MEDS_tabular_automl/utils.py +++ b/src/MEDS_tabular_automl/utils.py @@ -284,7 +284,7 @@ def write_df(df: pl.LazyFrame | pl.DataFrame | coo_array, fp: Path, do_overwrite def get_events_df(shard_df: pl.LazyFrame, feature_columns) -> pl.LazyFrame: - """Extracts and filters an Events LazyFrame with one row per observation (timestamps can be duplicated). + """Extracts and filters an Events LazyFrame with one row per observation (times can be duplicated). Args: shard_df: The LazyFrame shard from which to extract events. @@ -296,28 +296,26 @@ def get_events_df(shard_df: pl.LazyFrame, feature_columns) -> pl.LazyFrame: # Filter out feature_columns that were not present in the training set raw_feature_columns = ["/".join(c.split("/")[:-1]) for c in feature_columns] shard_df = shard_df.filter(pl.col("code").is_in(raw_feature_columns)) - # Drop rows with missing timestamp or code to get events - ts_shard_df = shard_df.drop_nulls(subset=["timestamp", "code"]) + # Drop rows with missing time or code to get events + ts_shard_df = shard_df.drop_nulls(subset=["time", "code"]) return ts_shard_df def get_unique_time_events_df(events_df: pl.LazyFrame) -> pl.LazyFrame: - """Ensures all timestamps in the events LazyFrame are unique and sorted by patient_id and timestamp. + """Ensures all times in the events LazyFrame are unique and sorted by patient_id and time. Args: events_df: Events LazyFrame to process. Returns: - A LazyFrame with unique timestamps, sorted by patient_id and timestamp. + A LazyFrame with unique times, sorted by patient_id and time. """ - assert events_df.select(pl.col("timestamp")).null_count().collect().item() == 0 + assert events_df.select(pl.col("time")).null_count().collect().item() == 0 # Check events_df is sorted - so it aligns with the ts_matrix we generate later in the pipeline events_df = ( - events_df.drop_nulls("timestamp") - .select(pl.col(["patient_id", "timestamp"])) - .unique(maintain_order=True) + events_df.drop_nulls("time").select(pl.col(["patient_id", "time"])).unique(maintain_order=True) ) - assert events_df.sort(by=["patient_id", "timestamp"]).collect().equals(events_df.collect()) + assert events_df.sort(by=["patient_id", "time"]).collect().equals(events_df.collect()) return events_df diff --git a/tests/test_integration.py b/tests/test_integration.py index ecc229b..d22eac5 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -45,10 +45,12 @@ def run_command(script: str, args: list[str], hydra_kwargs: dict[str, str], test def test_integration(): # Step 0: Setup Environment with tempfile.TemporaryDirectory() as d: - MEDS_cohort_dir = Path(d) / "processed" + MEDS_cohort_dir = Path(d) / "MEDS_cohort_dir" + output_cohort_dir = Path(d) / "output_cohort_dir" - describe_codes_config = { + shared_config = { "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), + "output_cohort_dir": str(output_cohort_dir.resolve()), "do_overwrite": False, "seed": 1, "hydra.verbose": True, @@ -56,6 +58,8 @@ def test_integration(): "loguru_init": True, } + describe_codes_config = {**shared_config} + with initialize( version_base=None, config_path="../src/MEDS_tabular_automl/configs/" ): # path to config.yaml @@ -63,16 +67,20 @@ def test_integration(): cfg = compose(config_name="describe_codes", overrides=overrides) # config.yaml # Create the directories - (MEDS_cohort_dir / "final_cohort").mkdir(parents=True, exist_ok=True) + (output_cohort_dir / "data").mkdir(parents=True, exist_ok=True) # Store MEDS outputs + all_data = [] for split, data in MEDS_OUTPUTS.items(): - file_path = MEDS_cohort_dir / "final_cohort" / f"{split}.parquet" + file_path = output_cohort_dir / "data" / f"{split}.parquet" file_path.parent.mkdir(exist_ok=True) - df = pl.read_csv(StringIO(data)) - df.with_columns(pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet( - file_path + df = pl.read_csv(StringIO(data)).with_columns( + pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f") ) + df.write_parquet(file_path) + all_data.append(df) + + all_data = pl.concat(all_data, how="diagonal_relaxed").sort(by=["patient_id", "time"]) # Check the files are not empty meds_files = list_subdir_files(Path(cfg.input_dir), "parquet") @@ -82,7 +90,7 @@ def test_integration(): for f in meds_files: assert pl.read_parquet(f).shape[0] > 0, "MEDS Data Tabular Dataframe Should not be Empty!" split_json = json.load(StringIO(SPLITS_JSON)) - splits_fp = MEDS_cohort_dir / "splits.json" + splits_fp = output_cohort_dir / ".shards.json" json.dump(split_json, splits_fp.open("w")) # Step 1: Run the describe_codes script @@ -92,7 +100,6 @@ def test_integration(): describe_codes_config, "describe_codes", ) - assert (Path(cfg.output_dir) / "config.yaml").is_file() assert Path(cfg.output_filepath).is_file() feature_columns = get_feature_columns(cfg.output_filepath) @@ -104,12 +111,7 @@ def test_integration(): # Step 2: Run the static data tabularization script tabularize_config = { - "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), - "do_overwrite": False, - "seed": 1, - "hydra.verbose": True, - "tqdm": False, - "loguru_init": True, + **shared_config, "tabularization.min_code_inclusion_frequency": 1, "tabularization.window_sizes": "[30d,365d,full]", } @@ -158,12 +160,7 @@ def test_integration(): # Step 3: Run the time series tabularization script tabularize_config = { - "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), - "do_overwrite": False, - "seed": 1, - "hydra.verbose": True, - "tqdm": False, - "loguru_init": True, + **shared_config, "tabularization.min_code_inclusion_frequency": 1, "tabularization.window_sizes": "[30d,365d,full]", } @@ -205,12 +202,7 @@ def test_integration(): ) # Step 4: Run the task_specific_caching script cache_config = { - "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), - "do_overwrite": False, - "seed": 1, - "hydra.verbose": True, - "tqdm": False, - "loguru_init": True, + **shared_config, "tabularization.min_code_inclusion_frequency": 1, "tabularization.window_sizes": "[30d,365d,full]", } @@ -219,22 +211,15 @@ def test_integration(): ): # path to config.yaml overrides = [f"{k}={v}" for k, v in cache_config.items()] cfg = compose(config_name="task_specific_caching", overrides=overrides) # config.yaml - # Create fake labels - for f in list_subdir_files(Path(cfg.MEDS_cohort_dir) / "final_cohort", "parquet"): - df = pl.scan_parquet(f) - df = get_unique_time_events_df(get_events_df(df, feature_columns)).collect() - pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]]) - df = df.with_columns(pl.Series(name="label", values=pseudo_labels)) - df = df.select(pl.col(["patient_id", "timestamp", "label"])) - df = df.with_row_index("event_id") - - split = f.parent.stem - shard_num = f.stem - out_f = Path(cfg.input_label_dir) / Path( - get_shard_prefix(Path(cfg.MEDS_cohort_dir) / "final_cohort", f) - ).with_suffix(".parquet") - out_f.parent.mkdir(parents=True, exist_ok=True) - df.write_parquet(out_f) + + df = get_unique_time_events_df(get_events_df(all_data.lazy(), feature_columns)).collect() + pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]]) + df = df.with_columns(pl.Series(name="boolean_value", values=pseudo_labels)) + df = df.select("patient_id", pl.col("time").alias("prediction_time"), "boolean_value") + + out_fp = Path(cfg.input_label_dir) / "0.parquet" + out_fp.parent.mkdir(parents=True, exist_ok=True) + df.write_parquet(out_fp) stderr, stdout_ws = run_command("generate-subsets", ["[30d]"], {}, "generate-subsets window_sizes") stderr, stdout_agg = run_command( diff --git a/tests/test_tabularize.py b/tests/test_tabularize.py index 3f64574..130721c 100644 --- a/tests/test_tabularize.py +++ b/tests/test_tabularize.py @@ -35,7 +35,7 @@ SPLITS_JSON = """{"train/0": [239684, 1195293], "train/1": [68729, 814703], "tuning/0": [754281], "held_out/0": [1500733]}""" # noqa: E501 MEDS_TRAIN_0 = """ -patient_id,code,timestamp,numerical_value +patient_id,code,time,numeric_value 239684,HEIGHT,,175.271115221764 239684,EYE_COLOR//BROWN,, 239684,DOB,1980-12-28T00:00:00.000000, @@ -68,7 +68,7 @@ 1195293,DISCHARGE,2010-06-20T20:50:04.000000, """ MEDS_TRAIN_1 = """ -patient_id,code,timestamp,numerical_value +patient_id,code,time,numeric_value 68729,EYE_COLOR//HAZEL,, 68729,HEIGHT,,160.3953106166676 68729,DOB,1978-03-09T00:00:00.000000, @@ -85,7 +85,7 @@ 814703,DISCHARGE,2010-02-05T07:02:30.000000, """ MEDS_HELD_OUT_0 = """ -patient_id,code,timestamp,numerical_value +patient_id,code,time,numeric_value 1500733,HEIGHT,,158.60131573580904 1500733,EYE_COLOR//BROWN,, 1500733,DOB,1986-07-20T00:00:00.000000, @@ -99,7 +99,7 @@ 1500733,DISCHARGE,2010-06-03T16:44:26.000000, """ MEDS_TUNING_0 = """ -patient_id,code,timestamp,numerical_value +patient_id,code,time,numeric_value 754281,EYE_COLOR//BROWN,, 754281,HEIGHT,,166.22261567137025 754281,DOB,1988-12-19T00:00:00.000000, @@ -148,10 +148,12 @@ def test_tabularize(): with tempfile.TemporaryDirectory() as d: - MEDS_cohort_dir = Path(d) / "processed" + MEDS_cohort_dir = Path(d) / "MEDS_cohort_dir" + output_cohort_dir = Path(d) / "output_cohort_dir" - describe_codes_config = { + shared_config = { "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), + "output_cohort_dir": str(output_cohort_dir.resolve()), "do_overwrite": False, "seed": 1, "hydra.verbose": True, @@ -159,6 +161,8 @@ def test_tabularize(): "loguru_init": True, } + describe_codes_config = {**shared_config} + with initialize( version_base=None, config_path="../src/MEDS_tabular_automl/configs/" ): # path to config.yaml @@ -166,16 +170,20 @@ def test_tabularize(): cfg = compose(config_name="describe_codes", overrides=overrides) # config.yaml # Create the directories - (MEDS_cohort_dir / "final_cohort").mkdir(parents=True, exist_ok=True) + (output_cohort_dir / "data").mkdir(parents=True, exist_ok=True) # Store MEDS outputs + all_data = [] for split, data in MEDS_OUTPUTS.items(): - file_path = MEDS_cohort_dir / "final_cohort" / f"{split}.parquet" + file_path = output_cohort_dir / "data" / f"{split}.parquet" file_path.parent.mkdir(exist_ok=True) - df = pl.read_csv(StringIO(data)) - df.with_columns(pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet( - file_path + df = pl.read_csv(StringIO(data)).with_columns( + pl.col("time").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f") ) + df.write_parquet(file_path) + all_data.append(df) + + all_data = pl.concat(all_data, how="diagonal_relaxed").sort(by=["patient_id", "time"]) # Check the files are not empty meds_files = list_subdir_files(Path(cfg.input_dir), "parquet") @@ -185,12 +193,11 @@ def test_tabularize(): for f in meds_files: assert pl.read_parquet(f).shape[0] > 0, "MEDS Data Tabular Dataframe Should not be Empty!" split_json = json.load(StringIO(SPLITS_JSON)) - splits_fp = MEDS_cohort_dir / "splits.json" + splits_fp = output_cohort_dir / ".shards.json" json.dump(split_json, splits_fp.open("w")) # Step 1: Describe Codes - compute code frequencies describe_codes.main(cfg) - assert (Path(cfg.output_dir) / "config.yaml").is_file() assert Path(cfg.output_filepath).is_file() feature_columns = get_feature_columns(cfg.output_filepath) @@ -202,12 +209,7 @@ def test_tabularize(): # Step 2: Tabularization tabularize_static_config = { - "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), - "do_overwrite": False, - "seed": 1, - "hydra.verbose": True, - "tqdm": False, - "loguru_init": True, + **shared_config, "tabularization.min_code_inclusion_frequency": 1, "tabularization.window_sizes": "[30d,365d,full]", } @@ -218,8 +220,11 @@ def test_tabularize(): overrides = [f"{k}={v}" for k, v in tabularize_static_config.items()] cfg = compose(config_name="tabularization", overrides=overrides) # config.yaml tabularize_static.main(cfg) - output_files = list(Path(cfg.output_dir).glob("**/static/**/*.npz")) - actual_files = [get_shard_prefix(Path(cfg.output_dir), each) + ".npz" for each in output_files] + + output_dir = Path(cfg.output_cohort_dir) / "tabularize" + + output_files = list(output_dir.glob("**/static/**/*.npz")) + actual_files = [get_shard_prefix(output_dir, each) + ".npz" for each in output_files] assert set(actual_files) == set(EXPECTED_STATIC_FILES) # Check the files are not empty for f in output_files: @@ -252,9 +257,9 @@ def test_tabularize(): tabularize_time_series.main(cfg) # confirm summary files exist: - output_files = list_subdir_files(cfg.output_dir, "npz") + output_files = list_subdir_files(str(output_dir.resolve()), "npz") actual_files = [ - get_shard_prefix(Path(cfg.output_dir), each) + ".npz" + get_shard_prefix(output_dir, each) + ".npz" for each in output_files if "none/static" not in str(each) ] @@ -282,12 +287,7 @@ def test_tabularize(): # Step 3: Cache Task data cache_config = { - "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), - "do_overwrite": False, - "seed": 1, - "hydra.verbose": True, - "tqdm": False, - "loguru_init": True, + **shared_config, "tabularization.min_code_inclusion_frequency": 1, "tabularization.window_sizes": "[30d,365d,full]", } @@ -299,31 +299,19 @@ def test_tabularize(): cfg = compose(config_name="task_specific_caching", overrides=overrides) # config.yaml # Create fake labels - for f in list_subdir_files(Path(cfg.MEDS_cohort_dir) / "final_cohort", "parquet"): - df = pl.scan_parquet(f) - df = get_unique_time_events_df(get_events_df(df, feature_columns)).collect() - pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]]) - df = df.with_columns(pl.Series(name="label", values=pseudo_labels)) - df = df.select(pl.col(["patient_id", "timestamp", "label"])) - df = df.with_row_index("event_id") - - split = f.parent.stem - shard_num = f.stem - out_f = Path(cfg.input_label_dir) / Path( - get_shard_prefix(Path(cfg.MEDS_cohort_dir) / "final_cohort", f) - ).with_suffix(".parquet") - out_f.parent.mkdir(parents=True, exist_ok=True) - df.write_parquet(out_f) + df = get_unique_time_events_df(get_events_df(all_data.lazy(), feature_columns)).collect() + pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]]) + df = df.with_columns(pl.Series(name="boolean_value", values=pseudo_labels)) + df = df.select("patient_id", pl.col("time").alias("prediction_time"), "boolean_value") + + out_fp = Path(cfg.input_label_dir) / "0.parquet" + out_fp.parent.mkdir(parents=True, exist_ok=True) + df.write_parquet(out_fp) cache_task.main(cfg) xgboost_config_kwargs = { - "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), - "do_overwrite": False, - "seed": 1, - "hydra.verbose": True, - "tqdm": False, - "loguru_init": True, + **shared_config, "tabularization.min_code_inclusion_frequency": 1, "tabularization.window_sizes": "[30d,365d,full]", } @@ -334,8 +322,10 @@ def test_tabularize(): overrides = [f"{k}={v}" for k, v in xgboost_config_kwargs.items()] cfg = compose(config_name="launch_xgboost", overrides=overrides) # config.yaml + output_dir = Path(cfg.output_cohort_dir) / "model" + launch_xgboost.main(cfg) - output_files = list(Path(cfg.output_dir).glob("**/*.json")) + output_files = list(output_dir.glob("**/*.json")) assert len(output_files) == 1 @@ -355,6 +345,7 @@ def test_xgboost_config(): stderr, stdout_agg = run_command("generate-subsets", ["[static/present]"], {}, "generate-subsets aggs") xgboost_config_kwargs = { "MEDS_cohort_dir": MEDS_cohort_dir, + "output_cohort_dir": "blah", "do_overwrite": False, "seed": 1, "hydra.verbose": True, @@ -369,5 +360,4 @@ def test_xgboost_config(): ): # path to config.yaml overrides = [f"{k}={v}" for k, v in xgboost_config_kwargs.items()] cfg = compose(config_name="launch_xgboost", overrides=overrides) # config.yaml - print(cfg.tabularization.window_sizes) assert cfg.tabularization.window_sizes