diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index b32a1bd..6aa8294 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -19,23 +19,21 @@ jobs: - name: Checkout uses: actions/checkout@v3 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v3 with: - python-version: "3.11" + python-version: "3.12" - name: Install packages run: | - pip install -e . - pip install pytest - pip install pytest-cov[toml] + pip install -e .[tests] #---------------------------------------------- # run test suite #---------------------------------------------- - name: Run tests run: | - pytest -v --doctest-modules --cov + pytest -v --doctest-modules --cov --ignore=hf_cohort/ - name: Upload coverage to Codecov uses: codecov/codecov-action@v4.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7540f52..1533f74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,6 +38,7 @@ repos: rev: v2.2.0 hooks: - id: autoflake + args: [--in-place, --remove-all-unused-imports] # python upgrading syntax to newer version - repo: https://github.com/asottile/pyupgrade diff --git a/README.md b/README.md index e1fd344..e7619e4 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,36 @@ This repository consists of two key pieces: what is more advanced is the efficient construction, storage, and loading of tabular features for the candidate AutoML models, enabling a far more extensive search over different featurization strategies. +### Scripts and Examples + +See `tests/test_tabularize_integration.py` for an example of the end-to-end pipeline being run on synthetic data. This +script is a functional test that is also run with `pytest` to verify the correctness of the algorithm. + +#### Core Scripts: + +1. `scripts/identify_columns.py` Loads all training shards to identify which feature columns + to generate tabular data for. +2. `scripts/tabularize_static.py` Iterates through shards and generates tabular vectors for + each patient. There is a single row per patient for each shard. +3. `scripts/summarize_over_windows.py` For each shard, iterates through window sizes and aggregations and + horizontally concatenates the outputs to generate the final tabular representations at every event time for + every patient. +4. `scripts/tabularize_merge` Aligns the time-series window aggregations (generated in the previous step) with + the static tabular vectors and caches them for training. +5. `scripts/hf_cohort/aces_task_extraction.py` Generates the task labels and caches them with the event_id + indexes which align them with the nearest prior event in the tabular data. +6. `scripts/xgboost_sweep.py` Tunes XGBoost models, iterating through the labels and corresponding tabular data. + +We run this on an example dataset using the following bash scripts in sequence: + +```bash +bash hf_cohort_shard.sh # processes the dataset into MEDS format +bash hf_cohort_e2e.sh # performs steps 1-4 above +bash hf_cohort/aces_task.sh # generates labels (step 5) +bash xgboost.sh # trains xgboost (step 6) +``` + + ## Feature Construction, Storage, and Loading Tabularization of a (raw) MEDS dataset is done by running the `scripts/data/tabularize.py` script. This script diff --git a/configs/tabularize.yaml b/configs/tabularize.yaml deleted file mode 100644 index 5d94c75..0000000 --- a/configs/tabularize.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# Raw data -MEDS_cohort_dir: ??? -tabularized_data_dir: ???
- -# Pre-processing -min_code_inclusion_frequency: ??? -window_sizes: ??? -codes: null -aggs: - - "code/count" - - "code/time_since_last" - - "code/time_since_first" - - "value/count" - - "value/sum" - - "value/sum_sqd" - - "value/min" - - "value/time_since_min" - - "value/max" - - "value/time_since_max" - - "value/last" - - "value/slope" - - "value/intercept" - - "value/residual/sum" - - "value/residual/sum_sqd" - - -# Sharding -n_patients_per_sub_shard: null - -# Misc -do_overwrite: False -seed: 1 - -# Hydra -hydra: - job: - name: tabularize_step_${now:%Y-%m-%d_%H-%M-%S} - run: - dir: ${tabularized_data_dir}/.logs/etl/${hydra.job.name} - sweep: - dir: ${tabularized_data_dir}/.logs/etl/${hydra.job.name} diff --git a/pyproject.toml b/pyproject.toml index 1aa7f41..8e53854 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,12 @@ -[build-system] -requires = ["setuptools>=61.0"] -build-backend = "setuptools.build_meta" - [project] name = "MEDS_tabularization" version = "0.0.1" authors = [ { name="Matthew McDermott", email="mattmcdermott8@gmail.com" }, + { name="Nassim Oufattole", email="noufattole@gmail.com" }, + { name="Teya Bergamaschi", email="teyabergamaschi@gmail.com" }, ] -description = "TODO" +description = "Scalable Tabularization of MEDS format Time-Series data" readme = "README.md" requires-python = ">=3.12" classifiers = [ @@ -16,12 +14,25 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dependencies = ["polars", "pyarrow", "loguru", "hydra-core", "numpy"] +dependencies = ["polars", "pyarrow", "loguru", "hydra-core", "numpy", "scipy", "pandas", "tqdm", "xgboost", "scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins"] + +[project.scripts] +meds-tab-describe = "MEDS_tabular_automl.scripts.describe_codes:main" +meds-tab-tabularize-static = "MEDS_tabular_automl.scripts.tabularize_static:main" +meds-tab-tabularize-time-series = "MEDS_tabular_automl.scripts.tabularize_time_series:main" +meds-tab-cache-task = "MEDS_tabular_automl.scripts.cache_task:main" +meds-tab-xgboost = "MEDS_tabular_automl.scripts.launch_xgboost:main" +meds-tab-xgboost-sweep = "MEDS_tabular_automl.scripts.sweep_xgboost:main" [project.optional-dependencies] dev = ["pre-commit"] tests = ["pytest", "pytest-cov", "rootutils"] +profiling = ["mprofile", "matplotlib"] + +[build-system] +requires = ["setuptools>=61.0", "setuptools-scm>=8.0", "wheel"] +build-backend = "setuptools.build_meta" [project.urls] -Homepage = "https://github.com/mmcdermott/MEDS_polars_functions" -Issues = "https://github.com/mmcdermott/MEDS_polars_functions/issues" +Homepage = "https://github.com/mmcdermott/MEDS_Tabular_AutoML" +Issues = "https://github.com/mmcdermott/MEDS_Tabular_AutoML/issues" diff --git a/src/MEDS_tabular_automl/configs/__init__.py b/src/MEDS_tabular_automl/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MEDS_tabular_automl/configs/default.yaml b/src/MEDS_tabular_automl/configs/default.yaml new file mode 100644 index 0000000..8f8513c --- /dev/null +++ b/src/MEDS_tabular_automl/configs/default.yaml @@ -0,0 +1,17 @@ +MEDS_cohort_dir: ??? 
+do_overwrite: False +seed: 1 +tqdm: False +worker: 0 +loguru_init: False + +log_dir: ${output_dir}/.logs/ + +hydra: + verbose: False + job: + name: MEDS_TAB_${name}_${worker}_{now:%Y-%m-%d_%H-%M-%S} + sweep: + dir: ${log_dir} + run: + dir: ${log_dir} diff --git a/src/MEDS_tabular_automl/configs/describe_codes.yaml b/src/MEDS_tabular_automl/configs/describe_codes.yaml new file mode 100644 index 0000000..d171513 --- /dev/null +++ b/src/MEDS_tabular_automl/configs/describe_codes.yaml @@ -0,0 +1,14 @@ +defaults: + - default + - _self_ + +# split we wish to get metadata for +split: train +# Raw data, must have a subdirectory "train" with the training data split +input_dir: ${MEDS_cohort_dir}/final_cohort/${split} +# Where to store output code frequency data +cache_dir: ${MEDS_cohort_dir}/.cache +output_dir: ${MEDS_cohort_dir} +output_filepath: ${output_dir}/code_metadata.parquet + +name: describe_codes diff --git a/src/MEDS_tabular_automl/configs/launch_xgboost.yaml b/src/MEDS_tabular_automl/configs/launch_xgboost.yaml new file mode 100644 index 0000000..123846f --- /dev/null +++ b/src/MEDS_tabular_automl/configs/launch_xgboost.yaml @@ -0,0 +1,81 @@ +defaults: + - default + - tabularization: default + - _self_ + +task_name: task +# min code frequency used for modeling, can potentially sweep over different values. +modeling_min_code_freq: 10 + +# Task cached data dir +input_dir: ${MEDS_cohort_dir}/${task_name}/task_cache +# Directory with task labels +input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels +# Where to output the model and cached data +output_dir: ${MEDS_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S} +output_filepath: ${output_dir}/model_metadata.parquet +cache_dir: ${MEDS_cohort_dir}/.cache + +# Model parameters +model_params: + num_boost_round: 1000 + early_stopping_rounds: 5 + model: + booster: gbtree + device: cpu + nthread: 1 + tree_method: hist + objective: binary:logistic + iterator: + keep_data_in_memory: True + binarize_task: True + +# Define search space for Optuna +optuna: + study_name: xgboost_sweep_${now:%Y-%m-%d_%H-%M-%S} + storage: null + load_if_exists: False + direction: minimize + sampler: null + pruner: null + + n_trials: 10 + n_jobs: 1 + show_progress_bar: False + + params: + suggest_categorical: + window_sizes: ${generate_permutations:${tabularization.window_sizes}} + aggs: ${generate_permutations:${tabularization.aggs}} + suggest_float: + eta: + low: .001 + high: 1 + log: True + lambda: + low: .001 + high: 1 + log: True + alpha: + low: .001 + high: 1 + log: True + subsample: + low: 0.5 + high: 1 + min_child_weight: + low: 1e-2 + high: 100 + suggest_int: + num_boost_round: + low: 10 + high: 1000 + max_depth: + low: 2 + high: 16 + min_code_inclusion_frequency: + low: 10 + high: 1_000_000 + log: True + +name: launch_xgboost diff --git a/src/MEDS_tabular_automl/configs/tabularization.yaml b/src/MEDS_tabular_automl/configs/tabularization.yaml new file mode 100644 index 0000000..dd40e3f --- /dev/null +++ b/src/MEDS_tabular_automl/configs/tabularization.yaml @@ -0,0 +1,12 @@ +defaults: + - default + - tabularization: default + - _self_ + +# Raw data +# Where the code metadata is stored +input_code_metadata_fp: ${MEDS_cohort_dir}/code_metadata.parquet +input_dir: ${MEDS_cohort_dir}/final_cohort +output_dir: ${MEDS_cohort_dir}/tabularize + +name: tabularization diff --git a/src/MEDS_tabular_automl/configs/tabularization/__init__.py b/src/MEDS_tabular_automl/configs/tabularization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/src/MEDS_tabular_automl/configs/tabularization/default.yaml b/src/MEDS_tabular_automl/configs/tabularization/default.yaml new file mode 100644 index 0000000..d11dd62 --- /dev/null +++ b/src/MEDS_tabular_automl/configs/tabularization/default.yaml @@ -0,0 +1,22 @@ +# User inputs +allowed_codes: null +min_code_inclusion_frequency: 10 +filtered_code_metadata_fp: ${MEDS_cohort_dir}/tabularized_code_metadata.parquet +window_sizes: + - "1d" + - "7d" + - "30d" + - "365d" + - "full" +aggs: + - "static/present" + - "static/first" + - "code/count" + - "value/count" + - "value/sum" + - "value/sum_sqd" + - "value/min" + - "value/max" + +# Resolved inputs +_resolved_codes: ${filter_to_codes:${tabularization.allowed_codes},${tabularization.min_code_inclusion_frequency},${tabularization.filtered_code_metadata_fp}} diff --git a/src/MEDS_tabular_automl/configs/task_specific_caching.yaml b/src/MEDS_tabular_automl/configs/task_specific_caching.yaml new file mode 100644 index 0000000..f1ca160 --- /dev/null +++ b/src/MEDS_tabular_automl/configs/task_specific_caching.yaml @@ -0,0 +1,14 @@ +defaults: + - default + - tabularization: default + - _self_ +task_name: task + +# Tabularized Data +input_dir: ${MEDS_cohort_dir}/tabularize +# Where the labels are stored, with columns patient_id, timestamp, label +input_label_dir: ${MEDS_cohort_dir}/${task_name}/labels +# Where to output the task specific tabularized data +output_dir: ${MEDS_cohort_dir}/${task_name}/task_cache + +name: task_specific_caching diff --git a/src/MEDS_tabular_automl/configs/tmp.yaml.yaml b/src/MEDS_tabular_automl/configs/tmp.yaml.yaml new file mode 100644 index 0000000..6312a45 --- /dev/null +++ b/src/MEDS_tabular_automl/configs/tmp.yaml.yaml @@ -0,0 +1,89 @@ +# Raw data +MEDS_cohort_dir: /storage/shared/meds_tabular_ml/ebcl_dataset/processed +tabularized_data_dir: ${MEDS_cohort_dir}/tabularize +task_dir: ${tabularized_data_dir}/task +model_dir: ${MEDS_cohort_dir}/model/model_${now:%Y-%m-%d_%H-%M-%S} +cache_dir: ${tabularized_data_dir}/.cache + +# Pre-processing +min_code_inclusion_frequency: 10 +window_sizes: [1d] +codes: null +aggs: + - "static/present" + # - "static/first" + - "code/count" + - "value/count" + - "value/sum" + - "value/sum_sqd" + - "value/min" + - "value/max" + +dynamic_threshold: 0.01 +numerical_value_threshold: 0.1 + +# Sharding +n_patients_per_sub_shard: null + +# Misc +do_overwrite: False +do_update: True +seed: 1 +tqdm: True +worker: 0 +test: False + +num_boost_round: 1000 +early_stopping_rounds: 5 +model: + booster: gbtree + device: cpu + tree_method: hist + objective: binary:logistic + +iterator: + keep_data_in_memory: True + binarize_task: True + +hydra: + verbose: False + run: + dir: ${model_dir}/.logs/ + +optuna: + storage: null + sampler: null + pruner: null + study_name: /home/teya/xgboost/tmp/xgboost_study_${now:%Y-%m-%d_%H-%M-%S} + direction: minimize + load_if_exists: True + show_progress_bar: False + n_trials: 10 + n_jobs: 3 + + params: + categorical: # choose single item from a list + window_sizes: + [ + [1d], + [7d], + [30d], + [365d], + [full], + [1d, 7d], + [1d, 7d, 30d], + [1d, 7d, 30d, 365d], + ] + # set: # choose any subset from a list + # window_sizes: [1d, 7d, 30d, 365d, full] # TODO: teya implement + # aggs: + # - "static/present" + # - "static/first" + # - "code/count" + # - "value/count" + # - "value/sum" + # - "value/sum_sqd" + # - "value/min" + # - "value/max" + integer: # choose integer value from a range [start, end, step] + min_code_inclusion_frequency: [10, 100, 10] diff --git 
a/src/MEDS_tabular_automl/describe_codes.py b/src/MEDS_tabular_automl/describe_codes.py new file mode 100644 index 0000000..de70682 --- /dev/null +++ b/src/MEDS_tabular_automl/describe_codes.py @@ -0,0 +1,198 @@ +from pathlib import Path + +import polars as pl +from omegaconf import DictConfig, OmegaConf + +from MEDS_tabular_automl.utils import DF_T, get_feature_names + + +def convert_to_df(freq_dict): + return pl.DataFrame([[col, freq] for col, freq in freq_dict.items()], schema=["code", "count"]) + + +def compute_feature_frequencies(cfg: DictConfig, shard_df: DF_T) -> list[str]: + """Generates a list of feature column names from the data within each shard based on specified + configurations. + + Parameters: + - cfg (DictConfig): Configuration dictionary specifying how features should be evaluated and aggregated. + - split_to_shard_df (dict): A dictionary of DataFrames, divided by data split (e.g., 'train', 'test'). + + Returns: + - tuple[list[str], dict]: A tuple containing a list of feature columns and a dictionary of code properties + identified during the evaluation. + + This function evaluates the properties of codes within training data and applies configured + aggregations to generate a comprehensive list of feature columns for modeling purposes. + Examples: + # >>> import polars as pl + # >>> data = {'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'], + # ... 'timestamp': [None, '2021-01-01', None, None, '2021-01-03', '2021-01-04', None], + # ... 'numerical_value': [1, None, 2, 2, None, None, 3]} + # >>> df = pl.DataFrame(data).lazy() + # >>> aggs = ['value/sum', 'code/count'] + # >>> compute_feature_frequencies(aggs, df) + # ['A/code', 'A/value', 'C/code', 'C/value'] + """ + static_df = shard_df.filter( + pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_null() + ) + static_code_freqs_df = static_df.group_by("code").agg(pl.count("code").alias("count")).collect() + static_code_freqs = { + row["code"] + "/static/present": row["count"] for row in static_code_freqs_df.iter_rows(named=True) + } + + static_value_df = static_df.filter(pl.col("numerical_value").is_not_null()) + static_value_freqs_df = ( + static_value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect() + ) + static_value_freqs = { + row["code"] + "/static/first": row["count"] for row in static_value_freqs_df.iter_rows(named=True) + } + + ts_df = shard_df.filter( + pl.col("patient_id").is_not_null() & pl.col("code").is_not_null() & pl.col("timestamp").is_not_null() + ) + code_freqs_df = ts_df.group_by("code").agg(pl.count("code").alias("count")).collect() + code_freqs = {row["code"] + "/code": row["count"] for row in code_freqs_df.iter_rows(named=True)} + + value_df = ts_df.filter(pl.col("numerical_value").is_not_null()) + value_freqs_df = value_df.group_by("code").agg(pl.count("numerical_value").alias("count")).collect() + value_freqs = {row["code"] + "/value": row["count"] for row in value_freqs_df.iter_rows(named=True)} + + combined_freqs = {**static_code_freqs, **static_value_freqs, **code_freqs, **value_freqs} + return convert_to_df(combined_freqs) + + +def convert_to_freq_dict(df: pl.LazyFrame) -> dict: + """Converts a DataFrame to a dictionary of frequencies. + + This function converts a DataFrame to a dictionary of frequencies, where the keys are the + column names and the values are dictionaries of code frequencies. + + Args: + - df (pl.DataFrame): The DataFrame to be converted. 
+ + Returns: + - dict: A dictionary of frequencies, where the keys are the column names and the values are + dictionaries of code frequencies. + + Example: + # >>> import polars as pl + # >>> df = pl.DataFrame({ + # ... "code": [1, 2, 3, 4, 5], + # ... "value": [10, 20, 30, 40, 50] + # ... }) + # >>> convert_to_freq_dict(df) + # {'code': {1: 1, 2: 1, 3: 1, 4: 1, 5: 1}, 'value': {10: 1, 20: 1, 30: 1, 40: 1, 50: 1}} + """ + if not df.columns == ["code", "count"]: + raise ValueError(f"DataFrame must have columns 'code' and 'count', but has columns {df.columns}!") + return dict(df.collect().iter_rows()) + + +def get_feature_columns(fp): + return sorted(list(convert_to_freq_dict(pl.scan_parquet(fp)).keys())) + + +def get_feature_freqs(fp): + return convert_to_freq_dict(pl.scan_parquet(fp)) + + +def filter_to_codes( + allowed_codes: list[str] | None, + min_code_inclusion_frequency: int, + code_metadata_fp: Path, +): + """Returns intersection of allowed codes if they are specified, and filters to codes based on inclusion + frequency.""" + if allowed_codes is None: + allowed_codes = get_feature_columns(code_metadata_fp) + feature_freqs = get_feature_freqs(code_metadata_fp) + + code_freqs = { + code: freq + for code, freq in feature_freqs.items() + if (freq >= min_code_inclusion_frequency and code in set(allowed_codes)) + } + return sorted([code for code, freq in code_freqs.items() if freq >= min_code_inclusion_frequency]) + + +OmegaConf.register_new_resolver("filter_to_codes", filter_to_codes) + + +def clear_code_aggregation_suffix(code): + if code.endswith("/code"): + return code[:-5] + elif code.endswith("/value"): + return code[:-6] + elif code.endswith("/static/present"): + return code[:-15] + elif code.endswith("/static/first"): + return code[:-13] + + +def filter_parquet(fp, allowed_codes: list[str]): + """Loads Parquet with Polars and filters to allowed codes. + + Args: + fp: Path to the Meds cohort shard + allowed_codes: List of codes to filter to. + + Expect: + >>> from tempfile import NamedTemporaryFile + >>> fp = NamedTemporaryFile() + >>> pl.DataFrame({ + ... "code": ["A", "A", "A", "A", "D", "D", "E", "E"], + ... "timestamp": [None, None, "2021-01-01", "2021-01-01", None, None, "2021-01-03", "2021-01-04"], + ... "numerical_value": [1, None, 2, 2, None, 5, None, 3] + ... 
}).write_parquet(fp.name) + >>> filter_parquet(fp.name, ["A/code", "D/static/present", "E/code", "E/value"]).collect() + shape: (6, 3) + ┌──────┬────────────┬─────────────────┐ + │ code ┆ timestamp ┆ numerical_value │ + │ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ i64 │ + ╞══════╪════════════╪═════════════════╡ + │ A ┆ 2021-01-01 ┆ null │ + │ A ┆ 2021-01-01 ┆ null │ + │ D ┆ null ┆ null │ + │ D ┆ null ┆ null │ + │ E ┆ 2021-01-03 ┆ null │ + │ E ┆ 2021-01-04 ┆ 3 │ + └──────┴────────────┴─────────────────┘ + >>> fp.close() + """ + df = pl.scan_parquet(fp) + # Drop values that are rare + # Drop Rare Static Codes + static_present_feature_columns = [ + clear_code_aggregation_suffix(each) for each in get_feature_names("static/present", allowed_codes) + ] + static_first_feature_columns = [ + clear_code_aggregation_suffix(each) for each in get_feature_names("static/first", allowed_codes) + ] + code_feature_columns = [ + clear_code_aggregation_suffix(each) for each in get_feature_names("code/count", allowed_codes) + ] + value_feature_columns = [ + clear_code_aggregation_suffix(each) for each in get_feature_names("value/sum", allowed_codes) + ] + + is_static_code = pl.col("timestamp").is_null() + is_numeric_code = pl.col("numerical_value").is_not_null() + rare_static_code = is_static_code & ~pl.col("code").is_in(static_present_feature_columns) + rare_ts_code = ~is_static_code & ~pl.col("code").is_in(code_feature_columns) + rare_ts_value = ~is_static_code & ~pl.col("code").is_in(value_feature_columns) & is_numeric_code + rare_static_value = is_static_code & ~pl.col("code").is_in(static_first_feature_columns) & is_numeric_code + + # Remove rare numeric values by converting them to null + df = df.with_columns( + pl.when(rare_static_value | rare_ts_value) + .then(None) + .otherwise(pl.col("numerical_value")) + .alias("numerical_value") + ) + # Drop rows with rare codes + df = df.filter(~(rare_static_code | rare_ts_code)) + return df diff --git a/src/MEDS_tabular_automl/file_name.py b/src/MEDS_tabular_automl/file_name.py new file mode 100644 index 0000000..898d11e --- /dev/null +++ b/src/MEDS_tabular_automl/file_name.py @@ -0,0 +1,28 @@ +"""Help functions for getting file names and paths for MEDS tabular automl tasks.""" +from pathlib import Path + + +def list_subdir_files(dir: [Path | str], fmt: str): + return sorted(list(Path(dir).glob(f"**/*.{fmt}"))) + + +def get_task_specific_path(cfg, split, shard_num, window_size, agg): + return Path(cfg.input_dir) / split / f"{shard_num}" / f"{window_size}" / f"{agg}.npz" + + +def get_model_files(cfg, split: str, shard_num: int): + window_sizes = cfg.tabularization.window_sizes + aggs = cfg.tabularization.aggs + # Given a shard number, returns the model files + model_files = [] + for window_size in window_sizes: + for agg in aggs: + if agg.startswith("static"): + continue + else: + model_files.append(get_task_specific_path(cfg, split, shard_num, window_size, agg)) + for agg in aggs: + if agg.startswith("static"): + window_size = "none" + model_files.append(get_task_specific_path(cfg, split, shard_num, window_size, agg)) + return sorted(model_files) diff --git a/src/MEDS_tabular_automl/generate_static_features.py b/src/MEDS_tabular_automl/generate_static_features.py new file mode 100644 index 0000000..c2164c4 --- /dev/null +++ b/src/MEDS_tabular_automl/generate_static_features.py @@ -0,0 +1,182 @@ +"""This module provides functions for generating static representations of patient data for use in automated +machine learning models. 
It includes functionality to summarize measurements based on static features and then +transform them into a tabular format suitable for analysis. The module leverages the polars library for +efficient data manipulation. + +Functions: +- _summarize_static_measurements: Summarizes static measurements from a given DataFrame. +- get_flat_static_rep: Produces a tabular representation of static data features. +""" + +import numpy as np +import polars as pl +from loguru import logger +from scipy.sparse import coo_array, csr_array + +from MEDS_tabular_automl.utils import ( + DF_T, + STATIC_CODE_AGGREGATION, + STATIC_VALUE_AGGREGATION, + get_events_df, + get_feature_names, + get_unique_time_events_df, + parse_static_feature_column, +) + + +def convert_to_matrix(df, num_events, num_features): + """Converts a Polars DataFrame to a sparse matrix.""" + dense_matrix = df.drop("patient_id").collect().to_numpy() + data_list = [] + rows = [] + cols = [] + for row in range(dense_matrix.shape[0]): + for col in range(dense_matrix.shape[1]): + data = dense_matrix[row, col] + if (data is not None) and (data != 0): + data_list.append(data) + rows.append(row) + cols.append(col) + matrix = csr_array((data_list, (rows, cols)), shape=(num_events, num_features)) + return matrix + + +def get_sparse_static_rep(static_features, static_df, meds_df, feature_columns) -> coo_array: + """Merges static and time-series dataframes. + + This function merges the static and time-series dataframes based on the patient_id column. + + Args: + - feature_columns (List[str]): A list of feature columns to include in the merged dataframe. + - static_df (pd.DataFrame): A dataframe containing static features. + - ts_df (pd.DataFrame): A dataframe containing time-series features. + + Returns: + - pd.DataFrame: A merged dataframe containing static and time-series features. + """ + # Make static data sparse and merge it with the time-series data + logger.info("Make static data sparse and merge it with the time-series data") + # Check static_df is sorted and unique + assert static_df.select(pl.col("patient_id")).collect().to_series().is_sorted() + assert ( + static_df.select(pl.len()).collect().item() + == static_df.select(pl.col("patient_id").n_unique()).collect().item() + ) + meds_df = get_unique_time_events_df(get_events_df(meds_df, feature_columns)) + + # load static data as sparse matrix + static_matrix = convert_to_matrix( + static_df, num_events=meds_df.select(pl.len()).collect().item(), num_features=len(static_features) + ) + # Duplicate static matrix rows to match time-series data + events_per_patient = ( + meds_df.select(pl.col("patient_id").value_counts()) + .unnest("patient_id") + .sort(by="patient_id") + .select(pl.col("count")) + .collect() + .to_series() + ) + reindex_slices = np.repeat(range(len(events_per_patient)), events_per_patient) + static_matrix = static_matrix[reindex_slices, :] + return coo_array(static_matrix) + + +def summarize_static_measurements( + agg: str, + feature_columns: list[str], + df: DF_T, +) -> pl.LazyFrame: + """Aggregates static measurements for feature columns that are marked as 'present' or 'first'. + + Parameters: + - feature_columns (list[str]): List of feature column identifiers that are specifically marked + for staticanalysis. + - df (DF_T): Data frame from which features will be extracted and summarized. + + Returns: + - pl.LazyFrame: A LazyFrame containing the summarized data pivoted by 'patient_id' + for each static feature. 
+ + This function first filters for features that need to be recorded as the first occurrence + or simply as present, then performs a pivot to reshape the data for each patient, providing + a tabular format where each row represents a patient and each column represents a static feature. + """ + if agg == STATIC_VALUE_AGGREGATION: + static_features = get_feature_names(agg=agg, feature_columns=feature_columns) + # Handling 'first' static values + static_first_codes = [parse_static_feature_column(c)[0] for c in static_features] + code_subset = df.filter(pl.col("code").is_in(static_first_codes)) + first_code_subset = code_subset.group_by(pl.col("patient_id")).first().collect() + static_value_pivot_df = first_code_subset.pivot( + index=["patient_id"], columns=["code"], values=["numerical_value"], aggregate_function=None + ) + # rename code to feature name + remap_cols = { + input_name: output_name + for input_name, output_name in zip(static_first_codes, static_features) + if input_name in static_value_pivot_df.columns + } + static_value_pivot_df = static_value_pivot_df.select( + *["patient_id"], *[pl.col(k).alias(v).cast(pl.Boolean) for k, v in remap_cols.items()] + ).sort(by="patient_id") + # pivot can be faster: https://stackoverflow.com/questions/73522017/replacing-a-pivot-with-a-lazy-groupby-operation # noqa: E501 + # TODO: consider casting with .cast(pl.Float32)) + return static_value_pivot_df + elif agg == STATIC_CODE_AGGREGATION: + static_features = get_feature_names(agg=agg, feature_columns=feature_columns) + # Handling 'present' static indicators + static_present_codes = [parse_static_feature_column(c)[0] for c in static_features] + static_present_pivot_df = ( + df.select(*["patient_id", "code"]) + .filter(pl.col("code").is_in(static_present_codes)) + .with_columns(pl.lit(True).alias("__indicator")) + .collect() + .pivot( + index=["patient_id"], + columns=["code"], + values="__indicator", + aggregate_function=None, + ) + .sort(by="patient_id") + ) + remap_cols = { + input_name: output_name + for input_name, output_name in zip(static_present_codes, static_features) + if input_name in static_present_pivot_df.columns + } + # rename columns to final feature names + static_present_pivot_df = static_present_pivot_df.select( + *["patient_id"], *[pl.col(k).alias(v).cast(pl.Boolean) for k, v in remap_cols.items()] + ) + return static_present_pivot_df + else: + raise ValueError(f"Invalid aggregation type: {agg}") + + +def get_flat_static_rep( + agg: str, + feature_columns: list[str], + shard_df: DF_T, +) -> coo_array: + """Produces a raw representation for static data from a specified shard DataFrame. + + Parameters: + - feature_columns (list[str]): List of feature columns to include in the static representation. + - shard_df (DF_T): The shard DataFrame containing patient data. + + Returns: + - pl.LazyFrame: A LazyFrame that includes all static features for the data provided. + + This function selects the appropriate static features, summarizes them using + _summarize_static_measurements, and then normalizes the resulting data to ensure it is + suitable for further analysis or machine learning tasks. 
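To make the sparse bookkeeping above concrete, the static pipeline amounts to two tricks: building a CSR matrix from `(data, (rows, cols))` triplets (as `convert_to_matrix` does) and repeating each patient's single static row once per event via `np.repeat` plus row indexing (as `get_sparse_static_rep` does). A minimal, self-contained sketch with toy values, illustrative only:

```python
import numpy as np
from scipy.sparse import csr_array

# Toy static table: one row per patient, one column per static feature.
dense = np.array([[1.0, 0.0], [0.0, 3.0]])

# Build a CSR matrix from (data, (rows, cols)) triplets, keeping the full shape.
rows, cols = np.nonzero(dense)
static = csr_array((dense[rows, cols], (rows, cols)), shape=dense.shape)

# Suppose patient 0 has 2 events and patient 1 has 3: repeat each static row per event.
events_per_patient = np.array([2, 3])
reindex_slices = np.repeat(np.arange(len(events_per_patient)), events_per_patient)
per_event_static = static[reindex_slices, :]

print(per_event_static.shape)  # (5, 2): one static feature row per event
```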
+ """ + static_features = get_feature_names(agg=agg, feature_columns=feature_columns) + static_measurements = summarize_static_measurements(agg, static_features, df=shard_df) + # convert to sparse_matrix + matrix = get_sparse_static_rep(static_features, static_measurements.lazy(), shard_df, feature_columns) + assert matrix.shape[1] == len( + static_features + ), f"Expected {len(static_features)} features, got {matrix.shape[1]}" + return matrix diff --git a/src/MEDS_tabular_automl/generate_summarized_reps.py b/src/MEDS_tabular_automl/generate_summarized_reps.py new file mode 100644 index 0000000..254a381 --- /dev/null +++ b/src/MEDS_tabular_automl/generate_summarized_reps.py @@ -0,0 +1,272 @@ +import numpy as np +import pandas as pd +import polars as pl + +pl.enable_string_cache() +from loguru import logger +from scipy.sparse import coo_array, csr_array, sparray + +from MEDS_tabular_automl.generate_ts_features import get_feature_names, get_flat_ts_rep +from MEDS_tabular_automl.utils import CODE_AGGREGATIONS, VALUE_AGGREGATIONS, load_tqdm + + +def sparse_aggregate(sparse_matrix, agg): + if agg == "sum": + merged_matrix = sparse_matrix.sum(axis=0, dtype=sparse_matrix.dtype) + elif agg == "min": + merged_matrix = sparse_matrix.min(axis=0) + elif agg == "max": + merged_matrix = sparse_matrix.max(axis=0) + elif agg == "sum_sqd": + merged_matrix = sparse_matrix.power(2).sum(axis=0, dtype=sparse_matrix.dtype) + elif agg == "count": + merged_matrix = sparse_matrix.getnnz(axis=0) + else: + raise ValueError(f"Aggregation method '{agg}' not implemented.") + return merged_matrix + + +def get_rolling_window_indicies(index_df, window_size): + """Get the indices for the rolling windows.""" + if window_size == "full": + timedelta = pd.Timedelta(150 * 52, unit="W") # just use 150 years as time delta + else: + timedelta = pd.Timedelta(window_size) + return ( + index_df.with_row_index("index") + .rolling(index_column="timestamp", period=timedelta, group_by="patient_id") + .agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")]) + .select(pl.col("min_index", "max_index")) + .collect() + ) + + +def aggregate_matrix(windows, matrix, agg, num_features, use_tqdm=False): + """Aggregate the matrix based on the windows.""" + tqdm = load_tqdm(use_tqdm) + agg = agg.split("/")[-1] + dtype = np.float32 + matrix = csr_array(matrix.astype(dtype)) + if agg.startswith("sum"): + out_dtype = np.float32 + else: + out_dtype = np.int32 + data, row, col = [], [], [] + for i, window in tqdm(enumerate(windows.iter_rows(named=True)), total=len(windows)): + min_index = window["min_index"] + max_index = window["max_index"] + subset_matrix = matrix[min_index : max_index + 1, :] + agg_matrix = sparse_aggregate(subset_matrix, agg).astype(out_dtype) + if isinstance(agg_matrix, np.ndarray): + nozero_ind = np.nonzero(agg_matrix)[0] + col.append(nozero_ind) + data.append(agg_matrix[nozero_ind]) + row.append(np.repeat(np.array(i, dtype=np.int32), len(nozero_ind))) + elif isinstance(agg_matrix, coo_array): + col.append(agg_matrix.col) + data.append(agg_matrix.data) + row.append(agg_matrix.row) + else: + raise TypeError(f"Invalid matrix type {type(agg_matrix)}") + row = np.concatenate(row) + out_matrix = coo_array( + (np.concatenate(data), (row, np.concatenate(col))), + dtype=out_dtype, + shape=(windows.shape[0], num_features), + ) + return csr_array(out_matrix) + + +def compute_agg(index_df, matrix: sparray, window_size: str, agg: str, num_features: int, use_tqdm=False): + """Applies aggreagtion to dataframe. 
+ + Dataframe is expected to only have the relevant columns for aggregating + It should have the patient_id and timestamp columns, and then only code columns + if agg is a code aggregation or only value columns if it is a value aggreagation. + + Example: + >>> from datetime import datetime + >>> df = pd.DataFrame({ + ... "patient_id": [1, 1, 1, 2], + ... "timestamp": [ + ... datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2020, 1, 3), datetime(2021, 1, 4) + ... ], + ... "A/code": [1, 1, 0, 0], + ... "B/code": [0, 0, 1, 1], + ... "C/code": [0, 0, 0, 0], + ... }) + >>> output = compute_agg(df, "1d", "code/count") + >>> output + 1d/A/code/count 1d/B/code/count 1d/C/code/count timestamp patient_id + 0 1 0 0 2021-01-01 1 + 1 2 0 0 2021-01-01 1 + 2 0 1 0 2020-01-01 1 + 0 0 1 0 2021-01-04 2 + >>> output.dtypes + 1d/A/code/count Sparse[int64, 0] + 1d/B/code/count Sparse[int64, 0] + 1d/C/code/count Sparse[int64, 0] + timestamp datetime64[ns] + patient_id int64 + dtype: object + """ + group_df = ( + index_df.with_row_index("index") + .group_by(["patient_id", "timestamp"], maintain_order=True) + .agg([pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index")]) + .collect() + ) + index_df = group_df.lazy().select(pl.col("patient_id", "timestamp")) + windows = group_df.select(pl.col("min_index", "max_index")) + logger.info("Step 1.5: Running sparse aggregation.") + matrix = aggregate_matrix(windows, matrix, agg, num_features, use_tqdm) + logger.info("Step 2: computing rolling windows and aggregating.") + windows = get_rolling_window_indicies(index_df, window_size) + logger.info("Starting final sparse aggregations.") + matrix = aggregate_matrix(windows, matrix, agg, num_features, use_tqdm) + return matrix + + +def _generate_summary( + ts_columns: list[str], + index_df: pd.DataFrame, + matrix: sparray, + window_size: str, + agg: str, + num_features, + use_tqdm=False, +) -> pl.LazyFrame: + """Generate a summary of the data frame for a given window size and aggregation. + + Args: + - df (DF_T): The data frame to summarize. + - window_size (str): The window size to use for the summary. + - agg (str): The aggregation to apply to the data frame. + + Returns: + - pl.LazyFrame: The summarized data frame. + + Expect: + >>> from datetime import datetime + >>> wide_df = pd.DataFrame({ + ... "patient_id": [1, 1, 1, 2], + ... "timestamp": [ + ... datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2020, 1, 3), datetime(2021, 1, 4) + ... ], + ... "A/code": [1, 1, 0, 0], + ... "B/code": [0, 0, 1, 1], + ... "C/code": [0, 0, 0, 0], + ... "A/value": [1, 2, 0, 0], + ... "B/value": [0, 0, 2, 2], + ... "C/value": [0, 0, 0, 0], + ... }) + >>> _generate_summary(wide_df, "full", "value/sum") + full/A/value/count full/B/value/count full/C/value/count timestamp patient_id + 0 1 0 0 2021-01-01 1 + 1 3 0 0 2021-01-01 1 + 2 3 2 0 2021-01-01 1 + 0 0 2 0 2021-01-04 2 + """ + if agg not in CODE_AGGREGATIONS + VALUE_AGGREGATIONS: + raise ValueError( + f"Invalid aggregation: {agg}. Valid options are: {CODE_AGGREGATIONS + VALUE_AGGREGATIONS}" + ) + out_matrix = compute_agg(index_df, matrix, window_size, agg, num_features, use_tqdm=use_tqdm) + return out_matrix + + +def generate_summary( + feature_columns: list[str], index_df: pl.LazyFrame, matrix: sparray, window_size, agg: str, use_tqdm=False +) -> pl.LazyFrame: + """Generate a summary of the data frame for given window sizes and aggregations. + + This function processes a dataframe to apply specified aggregations over defined window sizes. 
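Taken in isolation, the summarization above is a two-step pattern: use Polars' `rolling` to find, for each event, the row-index range of its trailing window (mirroring `get_rolling_window_indicies`), then reduce those rows of the sparse matrix (a plain sum below stands in for `sparse_aggregate`). A self-contained sketch on toy data; names and values are illustrative only:

```python
from datetime import datetime

import numpy as np
import polars as pl
from scipy.sparse import csr_array

# Toy event index (one row per event, sorted by patient and time) and a matching
# sparse matrix with one column per feature.
index_df = pl.LazyFrame(
    {
        "patient_id": [1, 1, 1, 2],
        "timestamp": [
            datetime(2021, 1, 1),
            datetime(2021, 1, 2),
            datetime(2021, 1, 10),
            datetime(2021, 1, 4),
        ],
    }
)
matrix = csr_array(np.array([[1, 0], [2, 0], [0, 3], [4, 4]], dtype=np.float32))

# For each event, find the [min_index, max_index] range of rows in its trailing window.
windows = (
    index_df.with_row_index("index")
    .rolling(index_column="timestamp", period="7d", group_by="patient_id")
    .agg(pl.col("index").min().alias("min_index"), pl.col("index").max().alias("max_index"))
    .collect()
)

# Reduce each window's rows with a sparse aggregation (sum here).
for row in windows.iter_rows(named=True):
    window_sum = matrix[row["min_index"] : row["max_index"] + 1, :].sum(axis=0)
    print(row["min_index"], row["max_index"], window_sum)
```

Working on contiguous row ranges keeps each window a cheap slice-and-reduce over the CSR matrix rather than a per-event join.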
+ It then joins the resulting frames on 'patient_id' and 'timestamp', and ensures all specified + feature columns exist in the final output, adding missing ones with default values. + + Args: + feature_columns (list[str]): List of all feature columns that must exist in the final output. + df (list[pl.LazyFrame]): The input dataframes to process, expected to be length 2 list with code_df + (pivoted shard with binary presence of codes) and value_df (pivoted shard with numerical values + for each code). + window_sizes (list[str]): List of window sizes to apply for summarization. + aggregations (list[str]): List of aggregations to perform within each window size. + + Returns: + pl.LazyFrame: A LazyFrame containing the summarized data with all required features present. + + Expect: + >>> from datetime import date + >>> wide_df = pd.DataFrame({"patient_id": [1, 1, 1, 2], + ... "A/code": [1, 1, 0, 0], + ... "B/code": [0, 0, 1, 1], + ... "A/value": [1, 2, 3, None], + ... "B/value": [None, None, None, 4.0], + ... "timestamp": [date(2021, 1, 1), date(2021, 1, 1),date(2020, 1, 3), date(2021, 1, 4)], + ... }) + >>> wide_df['timestamp'] = pd.to_datetime(wide_df['timestamp']) + >>> for col in ["A/code", "B/code", "A/value", "B/value"]: + ... wide_df[col] = pd.arrays.SparseArray(wide_df[col]) + >>> feature_columns = ["A/code/count", "B/code/count", "A/value/sum", "B/value/sum"] + >>> aggregations = ["code/count", "value/sum"] + >>> window_sizes = ["full", "1d"] + >>> generate_summary(feature_columns, wide_df, window_sizes, aggregations)[ + ... ["1d/A/code/count", "full/B/code/count", "full/B/value/sum"]] + 1d/A/code/count full/B/code/count full/B/value/sum + 0 NaN 1.0 0 + 1 NaN 1.0 0 + 2 NaN 1.0 0 + 0 NaN 1.0 0 + 0 NaN NaN 0 + 1 NaN NaN 0 + 2 NaN NaN 0 + 0 NaN NaN 0 + 0 0 NaN 0 + 1 1.0 NaN 0 + 2 2.0 NaN 0 + 0 0 NaN 0 + 0 NaN NaN 0 + 1 NaN NaN 0 + 2 NaN NaN 0 + 0 NaN NaN 0 + """ + assert len(feature_columns), "feature_columns must be a non-empty list" + ts_columns = get_feature_names(agg, feature_columns) + # Generate summaries for each window size and aggregation + code_type, _ = agg.split("/") + # only iterate through code_types that exist in the dataframe columns + assert any([c.endswith(code_type) for c in ts_columns]) + logger.info( + f"Generating aggregation {agg} for window_size {window_size}, with {len(ts_columns)} columns." 
+ ) + out_matrix = _generate_summary( + ts_columns, index_df, matrix, window_size, agg, len(ts_columns), use_tqdm=use_tqdm + ) + return out_matrix + + +if __name__ == "__main__": + import json + from pathlib import Path + + feature_columns = json.load( + open( + Path("/storage/shared/meds_tabular_ml/ebcl_dataset/processed/tabularize") / "feature_columns.json" + ) + ) + df = pl.scan_parquet( + Path("/storage/shared/meds_tabular_ml/ebcl_dataset/processed") + / "final_cohort" + / "train" + / "2.parquet" + ) + agg = "code/count" + index_df, sparse_matrix = get_flat_ts_rep(agg, feature_columns, df) + generate_summary( + feature_columns=feature_columns, + index_df=index_df, + matrix=sparse_matrix, + window_size="full", + agg=agg, + use_tqdm=True, + ) diff --git a/src/MEDS_tabular_automl/generate_ts_features.py b/src/MEDS_tabular_automl/generate_ts_features.py new file mode 100644 index 0000000..c4d244e --- /dev/null +++ b/src/MEDS_tabular_automl/generate_ts_features.py @@ -0,0 +1,160 @@ +import warnings + +import numpy as np +import pandas as pd +import polars as pl +from loguru import logger +from scipy.sparse import csr_array + +from MEDS_tabular_automl.utils import ( + CODE_AGGREGATIONS, + DF_T, + VALUE_AGGREGATIONS, + get_events_df, + get_feature_names, +) + +warnings.simplefilter(action="ignore", category=FutureWarning) + + +def feature_name_to_code(feature_name: str) -> str: + """Converts a feature name to a code name.""" + return "/".join(feature_name.split("/")[:-1]) + + +def get_long_code_df(df, ts_columns): + """Pivots the codes data frame to a long format one-hot rep for time series data.""" + column_to_int = {feature_name_to_code(col): i for i, col in enumerate(ts_columns)} + rows = range(df.select(pl.len()).collect().item()) + cols = ( + df.with_columns(pl.col("code").cast(str).replace(column_to_int).cast(int).alias("code_index")) + .select("code_index") + .collect() + .to_series() + .to_numpy() + ) + assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type" + data = np.ones(df.select(pl.len()).collect().item(), dtype=np.bool_) + return data, (rows, cols) + + +def get_long_value_df(df, ts_columns): + """Pivots the numerical value data frame to a long format for time series data.""" + column_to_int = {feature_name_to_code(col): i for i, col in enumerate(ts_columns)} + value_df = ( + df.with_row_index("index").drop_nulls("numerical_value").filter(pl.col("code").is_in(ts_columns)) + ) + rows = value_df.select(pl.col("index")).collect().to_series().to_numpy() + cols = ( + value_df.with_columns(pl.col("code").cast(str).replace(column_to_int).cast(int).alias("value_index")) + .select("value_index") + .collect() + .to_series() + .to_numpy() + ) + assert np.issubdtype(cols.dtype, np.number), "numerical_value must be a numerical type" + data = value_df.select(pl.col("numerical_value")).collect().to_series().to_numpy() + return data, (rows, cols) + + +def summarize_dynamic_measurements( + agg: str, + ts_columns: list[str], + df: pd.DataFrame, +) -> pd.DataFrame: + """Summarize dynamic measurements for feature columns that are marked as 'dynamic'. + + Args: + - ts_columns (list[str]): List of feature column identifiers that are specifically marked for dynamic + analysis. + - shard_df (DF_T): Data frame from which features will be extracted and summarized. + + Returns: + - pl.LazyFrame: A summarized data frame containing the dynamic features. + + Example: + >>> data = {'patient_id': [1, 1, 1, 2], + ... 'code': ['A', 'A', 'B', 'B'], + ... 
'timestamp': ['2021-01-01', '2021-01-01', '2020-01-01', '2021-01-04'], + ... 'numerical_value': [1, 2, 2, 2]} + >>> df = pd.DataFrame(data) + >>> ts_columns = ['A', 'B'] + >>> long_df = summarize_dynamic_measurements(ts_columns, df) + >>> long_df.head() + patient_id timestamp A/value B/value A/code B/code + 0 1 2021-01-01 1 0 1 0 + 1 1 2021-01-01 2 0 1 0 + 2 1 2020-01-01 0 2 0 1 + 3 2 2021-01-04 0 2 0 1 + >>> long_df.shape + (4, 6) + >>> long_df = summarize_dynamic_measurements(ts_columns, df[df.code == "A"]) + >>> long_df + patient_id timestamp A/value B/value A/code B/code + 0 1 2021-01-01 1 0 1 0 + 1 1 2021-01-01 2 0 1 0 + """ + logger.info("Generating Sparse matrix for Time Series Features") + id_cols = ["patient_id", "timestamp"] + + # Confirm dataframe is sorted + check_df = df.select(pl.col(id_cols)) + assert check_df.sort(by=id_cols).collect().equals(check_df.collect()), "data frame must be sorted" + + # Generate sparse matrix + if agg in CODE_AGGREGATIONS: + code_df = df.drop(*(id_cols + ["numerical_value"])) + data, (rows, cols) = get_long_code_df(code_df, ts_columns) + elif agg in VALUE_AGGREGATIONS: + value_df = df.drop(*id_cols) + data, (rows, cols) = get_long_value_df(value_df, ts_columns) + + sp_matrix = csr_array( + (data, (rows, cols)), + shape=(df.select(pl.len()).collect().item(), len(ts_columns)), + ) + return df.select(pl.col(id_cols)), sp_matrix + + +def get_flat_ts_rep( + agg: str, + feature_columns: list[str], + shard_df: DF_T, +) -> pl.LazyFrame: + """Produce a flat time series representation from a given data frame, focusing on non-static feature + columns. + + This function filters the given data frame for non-static features based on the 'feature_columns' + provided and generates a flat time series representation using these dynamic features. The resulting + data frame includes both codes and values transformed and aggregated appropriately. + + Args: + feature_columns (list[str]): A list of column identifiers that determine which features are considered + for dynamic analysis. + shard_df (DF_T): The data frame containing time-stamped data from which features will be extracted + and summarized. + + Returns: + pl.LazyFrame: A LazyFrame consisting of the processed time series data, combining both code and value + representations. + + Example: + >>> feature_columns = ['A/value', 'A/code', 'B/value', 'B/code', + ... "C/value", "C/code", "A/static/present"] + >>> data = {'patient_id': [1, 1, 1, 2, 2, 2], + ... 'code': ['A', 'A', 'B', 'B', 'C', 'C'], + ... 'timestamp': ['2021-01-01', '2021-01-01', '2020-01-01', '2021-01-04', None, None], + ... 
'numerical_value': [1, 2, 2, 2, 3, 4]} + >>> df = pl.DataFrame(data).lazy() + >>> pivot_df = get_flat_ts_rep(feature_columns, df) + >>> pivot_df + patient_id timestamp A/value B/value C/value A/code B/code C/code + 0 1 2021-01-01 1 0 0 1 0 0 + 1 1 2021-01-01 2 0 0 1 0 0 + 2 1 2020-01-01 0 2 0 0 1 0 + 3 2 2021-01-04 0 2 0 0 1 0 + """ + # Remove codes not in training set + shard_df = get_events_df(shard_df, feature_columns) + ts_columns = get_feature_names(agg, feature_columns) + return summarize_dynamic_measurements(agg, ts_columns, shard_df) diff --git a/src/MEDS_tabular_automl/mapper.py b/src/MEDS_tabular_automl/mapper.py new file mode 100644 index 0000000..34275b8 --- /dev/null +++ b/src/MEDS_tabular_automl/mapper.py @@ -0,0 +1,278 @@ +"""Basic utilities for parallelizable map operations on sharded MEDS datasets with caching and locking.""" + +import json +import shutil +from collections.abc import Callable +from datetime import datetime +from pathlib import Path + +from loguru import logger + +LOCK_TIME_FMT = "%Y-%m-%dT%H:%M:%S.%f" + + +def get_earliest_lock(cache_directory: Path) -> datetime | None: + """Returns the earliest start time of any lock file present in a cache directory, or None if none exist. + + Args: + cache_directory: The cache directory to check for the presence of a lock file. + + Examples: + >>> import tempfile + >>> directory = tempfile.TemporaryDirectory() + >>> root = Path(directory.name) + >>> empty_directory = root / "cache_empty" + >>> empty_directory.mkdir(exist_ok=True, parents=True) + >>> cache_directory = root / "cache_with_locks" + >>> locks_directory = cache_directory / "locks" + >>> locks_directory.mkdir(exist_ok=True, parents=True) + >>> time_1 = datetime(2021, 1, 1) + >>> time_1_str = time_1.strftime(LOCK_TIME_FMT) # "2021-01-01T00:00:00.000000" + >>> lock_fp_1 = locks_directory / f"{time_1_str}.json" + >>> _ = lock_fp_1.write_text(json.dumps({"start": time_1_str})) + >>> time_2 = datetime(2021, 1, 2, 3, 4, 5) + >>> time_2_str = time_2.strftime(LOCK_TIME_FMT) # "2021-01-02T03:04:05.000000" + >>> lock_fp_2 = locks_directory / f"{time_2_str}.json" + >>> _ = lock_fp_2.write_text(json.dumps({"start": time_2_str})) + >>> get_earliest_lock(cache_directory) + datetime.datetime(2021, 1, 1, 0, 0) + >>> get_earliest_lock(empty_directory) is None + True + >>> lock_fp_1.unlink() + >>> get_earliest_lock(cache_directory) + datetime.datetime(2021, 1, 2, 3, 4, 5) + >>> directory.cleanup() + """ + locks_directory = cache_directory / "locks" + + lock_times = [ + datetime.strptime(json.loads(lock_fp.read_text())["start"], LOCK_TIME_FMT) + for lock_fp in locks_directory.glob("*.json") + ] + + return min(lock_times) if lock_times else None + + +def register_lock(cache_directory: Path) -> tuple[datetime, Path]: + """Register a lock file in a cache directory. + + Args: + cache_directory: The cache directory to register a lock file in. + + Examples: + >>> import tempfile + >>> directory = tempfile.TemporaryDirectory() + >>> root = Path(directory.name) + >>> cache_directory = root / "cache_with_locks" + >>> lock_time, lock_fp = register_lock(cache_directory) + >>> assert (datetime.now() - lock_time).total_seconds() < 1, "Lock time should be ~ now." 
+ >>> lock_fp.is_file() + True + >>> lock_fp.read_text() == f'{{"start": "{lock_time.strftime(LOCK_TIME_FMT)}"}}' + True + >>> directory.cleanup() + """ + + lock_directory = cache_directory / "locks" + lock_directory.mkdir(exist_ok=True, parents=True) + + lock_time = datetime.now() + lock_fp = lock_directory / f"{lock_time.strftime(LOCK_TIME_FMT)}.json" + lock_fp.write_text(json.dumps({"start": lock_time.strftime(LOCK_TIME_FMT)})) + return lock_time, lock_fp + + +def wrap[ + DF_T +]( + in_fp: Path, + out_fp: Path, + read_fn: Callable[[Path], DF_T], + write_fn: Callable[[DF_T, Path], None], + *transform_fns: Callable[[DF_T], DF_T], + cache_intermediate: bool = True, + clear_cache_on_completion: bool = True, + do_overwrite: bool = False, + do_return: bool = False, +) -> tuple[bool, DF_T | None]: + """Wrap a series of file-in file-out map transformations on a dataframe with caching and locking. + + Args: + in_fp: The file path of the input dataframe. Must exist and be readable via `read_fn`. + out_fp: Output file path. The parent directory will be created if it does not exist. If this file + already exists, it will be deleted before any computations are done if `do_overwrite=True`, which + can result in data loss if the transformation functions do not complete successfully on + intermediate steps. If `do_overwrite` is `False` and this file exists, the function will use the + `read_fn` to read the file and return the dataframe directly. + read_fn: Function that reads the dataframe from a file. This must take as input a Path object and + return a dataframe of (generic) type DF_T. Ideally, this read function can make use of lazy + loading to further accelerate unnecessary reads when resuming from intermediate cached steps. + write_fn: Function that writes the dataframe to a file. This must take as input a dataframe of + (generic) type DF_T and a Path object, and will write the dataframe to that file. + transform_fns: A series of functions that transform the dataframe. Each function must take as input + a dataframe of (generic) type DF_T and return a dataframe of (generic) type DF_T. The functions + will be applied in the passed order. + cache_intermediate: If True, intermediate outputs of the transformations will be cached in a hidden + directory in the same parent directory as `out_fp` of the form + `{out_fp.parent}/.{out_fp.stem}_cache`. This can be useful for debugging and resuming from + intermediate steps when nontrivial transformations are composed. Cached files will be named + `step_{i}.output` where `i` is the index of the transformation function in `transform_fns`. **Note + that if you change the order of the transformations, the cache will be no longer valid but the + system will _not_ automatically delete the cache!**. This is `True` by default. + If `do_overwrite=True`, any prior individual cache files that are detected during the run will be + deleted before their corresponding step is run. If `do_overwrite=False` and a cache file exists, + that step of the transformation will be skipped and the cache file will be read directly. + clear_cache_on_completion: If True, the cache directory will be deleted after the final output is + written. This is `True` by default. + do_overwrite: If True, the output file will be overwritten if it already exists. This is `False` by + default. + do_return: If True, the final dataframe will be returned. This is `False` by default. + + Returns: + The dataframe resulting from the transformations applied in sequence to the dataframe stored in + `in_fp`. 
+ + Examples: + >>> import polars as pl + >>> import tempfile + >>> directory = tempfile.TemporaryDirectory() + >>> root = Path(directory.name) + >>> # For this example we'll use a simple CSV file, but in practice we *strongly* recommend using + >>> # Parquet files for performance reasons. + >>> in_fp = root / "input.csv" + >>> out_fp = root / "output.csv" + >>> in_df = pl.DataFrame({"a": [1, 3, 3], "b": [2, 4, 5], "c": [3, -1, 6]}) + >>> in_df.write_csv(in_fp) + >>> read_fn = pl.read_csv + >>> write_fn = pl.DataFrame.write_csv + >>> transform_fns = [ + ... lambda df: df.with_columns(pl.col("c") * 2), + ... lambda df: df.filter(pl.col("c") > 4) + ... ] + >>> result_computed = wrap(in_fp, out_fp, read_fn, write_fn, *transform_fns, do_return=False) + >>> assert result_computed + >>> print(out_fp.read_text()) + a,b,c + 1,2,6 + 3,5,12 + + >>> out_fp.unlink() + >>> cache_directory = root / f".output_cache" + >>> assert not cache_directory.is_dir() + >>> transform_fns = [ + ... lambda df: df.with_columns(pl.col("c") * 2), + ... lambda df: df.filter(pl.col("d") > 4) + ... ] + >>> wrap(in_fp, out_fp, read_fn, write_fn, *transform_fns) + Traceback (most recent call last): + ... + polars.exceptions.ColumnNotFoundError: unable to find column "d"; valid columns: ["a", "b", "c"] + >>> assert cache_directory.is_dir() + >>> cache_fp = cache_directory / "step_0.output" + >>> pl.read_csv(cache_fp) + shape: (3, 3) + ┌─────┬─────┬─────┐ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╡ + │ 1 ┆ 2 ┆ 6 │ + │ 3 ┆ 4 ┆ -2 │ + │ 3 ┆ 5 ┆ 12 │ + └─────┴─────┴─────┘ + >>> shutil.rmtree(cache_directory) + >>> lock_dir = cache_directory / "locks" + >>> assert not lock_dir.exists() + >>> def lock_dir_checker_fn(df: pl.DataFrame) -> pl.DataFrame: + ... print(f"Lock dir exists? {lock_dir.exists()}") + ... return df + >>> result_computed, out_df = wrap( + ... in_fp, out_fp, read_fn, write_fn, lock_dir_checker_fn, do_return=True + ... ) + Lock dir exists? True + >>> assert result_computed + >>> out_df + shape: (3, 3) + ┌─────┬─────┬─────┐ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╡ + │ 1 ┆ 2 ┆ 3 │ + │ 3 ┆ 4 ┆ -1 │ + │ 3 ┆ 5 ┆ 6 │ + └─────┴─────┴─────┘ + >>> directory.cleanup() + """ + + if out_fp.is_file(): + if do_overwrite: + logger.info(f"Deleting existing {out_fp} as do_overwrite={do_overwrite}.") + out_fp.unlink() + else: + logger.info(f"{out_fp} exists; reading directly and returning.") + if do_return: + return True, read_fn(out_fp) + else: + return True + + cache_directory = out_fp.parent / f".{out_fp.stem}_cache" + cache_directory.mkdir(exist_ok=True, parents=True) + + earliest_lock_time = get_earliest_lock(cache_directory) + if earliest_lock_time is not None: + logger.info(f"{out_fp} is in progress as of {earliest_lock_time}. Returning.") + return False, None if do_return else False + + st_time, lock_fp = register_lock(cache_directory) + + logger.info(f"Registered lock at {st_time}. Double checking no earlier locks have been registered.") + earliest_lock_time = get_earliest_lock(cache_directory) + if earliest_lock_time < st_time: + logger.info(f"Earlier lock found at {earliest_lock_time}. 
Deleting current lock and returning.") + lock_fp.unlink() + return False, None if do_return else False + + logger.info(f"Reading input dataframe from {in_fp}") + df = read_fn(in_fp) + logger.info("Read dataset") + + try: + for i, transform_fn in enumerate(transform_fns): + cache_fp = cache_directory / f"step_{i}.output" + + st_time_step = datetime.now() + if cache_fp.is_file(): + if do_overwrite: + logger.info( + f"Deleting existing cached output for step {i} " f"as do_overwrite={do_overwrite}" + ) + cache_fp.unlink() + else: + logger.info(f"Reading cached output for step {i}") + df = read_fn(cache_fp) + else: + df = transform_fn(df) + + if cache_intermediate and i < len(transform_fns) - 1: + logger.info(f"Writing intermediate output for step {i} to {cache_fp}") + write_fn(df, cache_fp) + logger.info(f"Completed step {i} in {datetime.now() - st_time_step}") + + logger.info(f"Writing final output to {out_fp}") + write_fn(df, out_fp) + logger.info(f"Succeeded in {datetime.now() - st_time}") + if clear_cache_on_completion: + logger.info(f"Clearing cache directory {cache_directory}") + shutil.rmtree(cache_directory) + else: + logger.info(f"Leaving cache directory {cache_directory}, but clearing lock at {lock_fp}") + lock_fp.unlink() + if do_return: + return True, df + else: + return True + except Exception as e: + logger.warning(f"Clearing lock due to Exception {e} at {lock_fp} after {datetime.now() - st_time}") + lock_fp.unlink() + raise e diff --git a/src/MEDS_tabular_automl/scripts/__init__.py b/src/MEDS_tabular_automl/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MEDS_tabular_automl/scripts/cache_task.py b/src/MEDS_tabular_automl/scripts/cache_task.py new file mode 100644 index 0000000..5f0aff3 --- /dev/null +++ b/src/MEDS_tabular_automl/scripts/cache_task.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python + +"""Aggregates time-series data for feature columns across different window sizes.""" +from importlib.resources import files +from pathlib import Path + +import hydra +import numpy as np +import polars as pl +import scipy.sparse as sp +from omegaconf import DictConfig + +from MEDS_tabular_automl.file_name import list_subdir_files +from MEDS_tabular_automl.mapper import wrap as rwlock_wrap +from MEDS_tabular_automl.utils import ( + CODE_AGGREGATIONS, + STATIC_CODE_AGGREGATION, + STATIC_VALUE_AGGREGATION, + VALUE_AGGREGATIONS, + get_shard_prefix, + hydra_loguru_init, + load_matrix, + load_tqdm, + write_df, +) + +config_yaml = files("MEDS_tabular_automl").joinpath("configs/task_specific_caching.yaml") +if not config_yaml.is_file(): + raise FileNotFoundError("Core configuration not successfully installed!") + + +VALID_AGGREGATIONS = [ + *VALUE_AGGREGATIONS, + *CODE_AGGREGATIONS, + STATIC_CODE_AGGREGATION, + STATIC_VALUE_AGGREGATION, +] + + +def generate_row_cached_matrix(matrix, label_df): + """Generates row-cached matrix for a given matrix and label_df.""" + label_len = label_df.select(pl.len()).collect().item() + if not matrix.shape[0] == label_len: + raise ValueError( + f"Matrix and label_df must have the same number of rows: {matrix.shape[0]} != {label_len}" + ) + csr = sp.csr_array(matrix) + valid_ids = label_df.select(pl.col("event_id")).collect().to_series().to_numpy() + csr = csr[valid_ids, :] + return sp.coo_array(csr) + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem) +def main( + cfg: DictConfig, +): + """Performs row splicing of tabularized data for a specific task.""" + iter_wrapper = 
load_tqdm(cfg.tqdm) + if not cfg.loguru_init: + hydra_loguru_init() + # Produce ts representation + + # shuffle tasks + tabularization_tasks = list_subdir_files(cfg.input_dir, "npz") + np.random.shuffle(tabularization_tasks) + + # iterate through them + for data_fp in iter_wrapper(tabularization_tasks): + # parse as time series agg + split, shard_num, window_size, code_type, agg_name = Path(data_fp).with_suffix("").parts[-5:] + label_fp = Path(cfg.input_label_dir) / split / f"{shard_num}.parquet" + out_fp = (Path(cfg.output_dir) / get_shard_prefix(cfg.input_dir, data_fp)).with_suffix(".npz") + assert label_fp.exists(), f"Output file {label_fp} does not exist." + + def read_fn(fps): + matrix_fp, label_fp = fps + return load_matrix(matrix_fp), pl.scan_parquet(label_fp) + + def compute_fn(shard_dfs): + matrix, label_df = shard_dfs + cache_matrix = generate_row_cached_matrix(matrix, label_df) + return cache_matrix + + def write_fn(cache_matrix, out_fp): + write_df(cache_matrix, out_fp, do_overwrite=cfg.do_overwrite) + + in_fps = [data_fp, label_fp] + rwlock_wrap( + in_fps, + out_fp, + read_fn, + write_fn, + compute_fn, + do_overwrite=cfg.do_overwrite, + do_return=False, + ) + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_tabular_automl/scripts/describe_codes.py b/src/MEDS_tabular_automl/scripts/describe_codes.py new file mode 100644 index 0000000..034244a --- /dev/null +++ b/src/MEDS_tabular_automl/scripts/describe_codes.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python +"""This Python script, stores the configuration parameters and feature columns used in the output.""" +from collections import defaultdict +from importlib.resources import files +from pathlib import Path + +import hydra +import numpy as np +import polars as pl +from loguru import logger +from omegaconf import DictConfig + +from MEDS_tabular_automl.describe_codes import ( + compute_feature_frequencies, + convert_to_df, + convert_to_freq_dict, +) +from MEDS_tabular_automl.file_name import list_subdir_files +from MEDS_tabular_automl.mapper import wrap as rwlock_wrap +from MEDS_tabular_automl.utils import ( + get_shard_prefix, + hydra_loguru_init, + load_tqdm, + store_config_yaml, + write_df, +) + +config_yaml = files("MEDS_tabular_automl").joinpath("configs/describe_codes.yaml") +if not config_yaml.is_file(): + raise FileNotFoundError("Core configuration not successfully installed!") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem) +def main( + cfg: DictConfig, +): + """Computes the feature frequencies so we can filter out infrequent events. + + Args: + cfg: The configuration object for the tabularization process. + """ + iter_wrapper = load_tqdm(cfg.tqdm) + if not cfg.loguru_init: + hydra_loguru_init() + + # Store Config + output_dir = Path(cfg.output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + store_config_yaml(output_dir / "config.yaml", cfg) + + # Create output dir + input_dir = Path(cfg.input_dir) + input_dir.mkdir(exist_ok=True, parents=True) + + # 0. 
Identify Output Columns and Frequencies + logger.info("Iterating through shards and caching feature frequencies.") + + def compute_fn(shard_df): + return compute_feature_frequencies(cfg, shard_df) + + def write_fn(df, out_fp): + write_df(df, out_fp) + + def read_fn(in_fp): + return pl.scan_parquet(in_fp) + + # Map: Iterates through shards and caches feature frequencies + train_shards = list_subdir_files(cfg.input_dir, "parquet") + np.random.shuffle(train_shards) + for shard_fp in iter_wrapper(train_shards): + out_fp = (Path(cfg.cache_dir) / get_shard_prefix(cfg.input_dir, shard_fp)).with_suffix( + shard_fp.suffix + ) + rwlock_wrap( + shard_fp, + out_fp, + read_fn, + write_fn, + compute_fn, + do_overwrite=cfg.do_overwrite, + do_return=False, + ) + + logger.info("Summing frequency computations.") + # Reduce: sum the frequency computations + + def compute_fn(freq_df_list): + feature_freqs = defaultdict(int) + for shard_freq_df in freq_df_list: + shard_freq_dict = convert_to_freq_dict(shard_freq_df) + for feature, freq in shard_freq_dict.items(): + feature_freqs[feature] += freq + feature_df = convert_to_df(feature_freqs) + return feature_df + + def write_fn(df, out_fp): + write_df(df, out_fp) + + def read_fn(feature_dir): + files = list_subdir_files(feature_dir, "parquet") + return [pl.scan_parquet(fp) for fp in files] + + rwlock_wrap( + Path(cfg.cache_dir), + Path(cfg.output_filepath), + read_fn, + write_fn, + compute_fn, + do_overwrite=cfg.do_overwrite, + do_return=False, + ) + logger.info("Stored feature columns and frequencies.") + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_tabular_automl/scripts/launch_xgboost.py b/src/MEDS_tabular_automl/scripts/launch_xgboost.py new file mode 100644 index 0000000..6babbcc --- /dev/null +++ b/src/MEDS_tabular_automl/scripts/launch_xgboost.py @@ -0,0 +1,425 @@ +from collections.abc import Callable, Mapping +from importlib.resources import files +from pathlib import Path + +import hydra +import numpy as np +import polars as pl +import scipy.sparse as sp +import xgboost as xgb +from loguru import logger +from mixins import TimeableMixin +from omegaconf import DictConfig, OmegaConf +from sklearn.metrics import roc_auc_score + +from MEDS_tabular_automl.describe_codes import get_feature_columns, get_feature_freqs +from MEDS_tabular_automl.file_name import get_model_files, list_subdir_files +from MEDS_tabular_automl.utils import get_feature_indices, hydra_loguru_init + +config_yaml = files("MEDS_tabular_automl").joinpath("configs/launch_xgboost.yaml") +if not config_yaml.is_file(): + raise FileNotFoundError("Core configuration not successfully installed!") + + +class Iterator(xgb.DataIter, TimeableMixin): + """Iterator class for loading and processing data shards. + + This class provides functionality for iterating through data shards, loading + feature data and labels, and processing them based on the provided configuration. + + Args: + cfg: A configuration dictionary containing parameters for + data processing, feature selection, and other settings. + split: The data split to use, which can be one of "train", "tuning", + or "held_out". This determines which subset of the data is loaded and processed. + + Attributes: + cfg: Configuration dictionary containing parameters for + data processing, feature selection, and other settings. + file_name_resolver: Object for resolving file names and paths based on the configuration. + split: The data split being used for loading and processing data shards. + _data_shards: List of data shard names. 
+ valid_event_ids: Dictionary mapping shard number to a list of valid event IDs. + labels: Dictionary mapping shard number to a list of labels for the corresponding event IDs. + codes_set: Set of codes to include in the data. + code_masks: Dictionary of code masks for filtering features based on aggregation. + num_features: Total number of features in the data. + """ + + def __init__(self, cfg: DictConfig, split: str = "train"): + """Initializes the Iterator with the provided configuration and data split. + + Args: + cfg: A configuration dictionary containing parameters for + data processing, feature selection, and other settings. + split: The data split to use, which can be one of "train", "tuning", + or "held_out". This determines which subset of the data is loaded and processed. + """ + # generate_permutations(cfg.tabularization.window_sizes) + # generate_permutations(cfg.tabularization.aggs) + self.cfg = cfg + self.split = split + # Load shards for this split + self._data_shards = sorted( + [shard.stem for shard in list_subdir_files(Path(cfg.input_label_dir) / split, "parquet")] + ) + self.valid_event_ids, self.labels = self.load_labels() + self.codes_set, self.code_masks, self.num_features = self._get_code_set() + self._it = 0 + + super().__init__(cache_prefix=Path(cfg.cache_dir)) + + @TimeableMixin.TimeAs + def _get_code_masks(self, feature_columns: list, codes_set: set) -> Mapping[str, list[bool]]: + """Create boolean masks for filtering features. + + Creates a dictionary of boolean masks for each aggregation type. The masks are used to filter + the feature columns based on the specified included codes and minimum code inclusion frequency. + + Args: + feature_columns: List of feature columns. + codes_set: Set of codes to include. + + Returns: + Dictionary of code masks for each aggregation. + """ + code_masks = {} + for agg in set(self.cfg.tabularization.aggs): + feature_ids = get_feature_indices(agg, feature_columns) + code_mask = [True if idx in codes_set else False for idx in feature_ids] + code_masks[agg] = code_mask + return code_masks + + @TimeableMixin.TimeAs + def _load_matrix(self, path: Path) -> sp.csc_matrix: + """Load a sparse matrix from disk. + + Args: + - path (Path): Path to the sparse matrix. + + Returns: + - sp.csc_matrix: Sparse matrix. + """ + npzfile = np.load(path) + array, shape = npzfile["array"], npzfile["shape"] + if array.shape[0] != 3: + raise ValueError(f"Expected array to have 3 rows, but got {array.shape[0]} rows") + data, row, col = array + return sp.csc_matrix((data, (row, col)), shape=shape) + + @TimeableMixin.TimeAs + def load_labels(self) -> tuple[Mapping[int, list], Mapping[int, list]]: + """Loads valid event ids and labels for each shard. 
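+
+        Label parquet files are read from ``input_label_dir/<split>/<shard>.parquet``. A minimal sketch of
+        the expected columns (mirroring what the task caching step and the integration tests produce; real
+        task label files may carry additional columns):
+
+            event_id (int): row index aligning each label with the tabularized sparse matrices
+            label (numeric): the task label; binarized to {0, 1} when
+                ``model_params.iterator.binarize_task`` is set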
+ + Returns: + - Tuple[Mapping[int, list], Mapping[int, list]]: Tuple containing: + dictionary from shard number to list of valid event ids -- used for indexing rows + in the sparse matrix + dictionary from shard number to list of labels for these valid event ids + """ + label_fps = { + shard: (Path(self.cfg.input_label_dir) / self.split / shard).with_suffix(".parquet") + for shard in self._data_shards + for shard in self._data_shards + } + cached_labels, cached_event_ids = dict(), dict() + for shard, label_fp in label_fps.items(): + label_df = pl.scan_parquet(label_fp) + cached_event_ids[shard] = label_df.select(pl.col("event_id")).collect().to_series() + + # TODO: check this for Nan or any other case we need to worry about + cached_labels[shard] = label_df.select(pl.col("label")).collect().to_series() + if self.cfg.model_params.iterator.binarize_task: + cached_labels[shard] = cached_labels[shard].map_elements( + lambda x: 1 if x > 0 else 0, return_dtype=pl.Int8 + ) + + return cached_event_ids, cached_labels + + @TimeableMixin.TimeAs + def _get_code_set(self) -> tuple[set, Mapping[int, list], int]: + """Get the set of codes to include in the data based on the configuration.""" + feature_columns = get_feature_columns(self.cfg.tabularization.filtered_code_metadata_fp) + feature_freqs = get_feature_freqs(self.cfg.tabularization.filtered_code_metadata_fp) + feature_columns = [ + col + for col in feature_columns + if feature_freqs[col] >= self.cfg.tabularization.min_code_inclusion_frequency + ] + feature_dict = {col: i for i, col in enumerate(feature_columns)} + allowed_codes = set(self.cfg.tabularization._resolved_codes) + codes_set = {feature_dict[code] for code in feature_dict if code in allowed_codes} + + return ( + codes_set, + self._get_code_masks(feature_columns, codes_set), + len(feature_columns), + ) + + @TimeableMixin.TimeAs + def _load_dynamic_shard_from_file(self, path: Path, idx: int) -> sp.csc_matrix: + """Load a sparse shard into memory. + + Args: + - path (Path): Path to the sparse shard. + + Returns: + - sp.csc_matrix: Data frame with the sparse shard. + """ + # column_shard is of form event_idx, feature_idx, value + matrix = self._load_matrix(path) + if path.stem in ["first", "present"]: + agg = f"static/{path.stem}" + else: + agg = f"{path.parent.stem}/{path.stem}" + + return self._filter_shard_on_codes_and_freqs(agg, matrix) + + @TimeableMixin.TimeAs + def _get_dynamic_shard_by_index(self, idx: int) -> sp.csc_matrix: + """Load a specific shard of dynamic data from disk and return it as a sparse matrix after filtering + column inclusion. + + Args: + - idx (int): Index of the shard to load. + + Returns: + - sp.csc_matrix: Filtered sparse matrix. 
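+
+        Example (illustrative layout of the ``.npz`` inputs this method horizontally concatenates; paths
+        follow the ``<split>/<shard>/<window_size>/<agg>.npz`` convention used by the tabularization
+        scripts, with static aggregations stored under the ``none`` window):
+
+            train/0/30d/code/count.npz
+            train/0/30d/value/sum.npz
+            train/0/365d/code/count.npz
+            train/0/none/static/present.npz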
+ """ + # get all window_size x aggreagation files using the file resolver + files = get_model_files(self.cfg, self.split, self._data_shards[idx]) + + if not all(file.exists() for file in files): + raise ValueError(f"Not all files exist for shard {self._data_shards[idx]}") + + dynamic_cscs = [self._load_dynamic_shard_from_file(file, idx) for file in files] + + fn_name = "_get_dynamic_shard_by_index" + hstack_key = f"{fn_name}/hstack" + self._register_start(key=hstack_key) + + combined_csc = sp.hstack(dynamic_cscs, format="csc") # TODO: check this + # self._register_end(key=hstack_key) + # # Filter Rows + # valid_indices = self.valid_event_ids[shard_name] + # filter_key = f"{fn_name}/filter" + # self._register_start(key=filter_key) + # out = combined_csc[valid_indices, :] + # self._register_end(key=filter_key) + return combined_csc + + @TimeableMixin.TimeAs + def _get_shard_by_index(self, idx: int) -> tuple[sp.csc_matrix, np.ndarray]: + """Load a specific shard of data from disk and concatenate with static data. + + Args: + - idx (int): Index of the shard to load. + + Returns: + - X (scipy.sparse.csc_matrix): Feature data frame.ß + - y (numpy.ndarray): Labels. + """ + dynamic_df = self._get_dynamic_shard_by_index(idx) + label_df = self.labels[self._data_shards[idx]] + return dynamic_df, label_df + + @TimeableMixin.TimeAs + def _filter_shard_on_codes_and_freqs(self, agg: str, df: sp.csc_matrix) -> sp.csc_matrix: + """Filter the dynamic data frame based on the inclusion sets. Given the codes_mask, filter the data + frame to only include columns that are True in the mask. + + Args: + - df (scipy.sparse.csc_matrix): Data frame to filter. + + Returns: + - df (scipy.sparse.sp.csc_matrix): Filtered data frame. + """ + if self.codes_set is None: + return df + + ckey = f"_filter_shard_on_codes_and_freqs/{agg}" + self._register_start(key=ckey) + + df = df[:, self.code_masks[agg]] + + self._register_end(key=ckey) + + return df + + @TimeableMixin.TimeAs + def next(self, input_data: Callable): + """Advance the iterator by 1 step and pass the data to XGBoost. This function is called by XGBoost + during the construction of ``DMatrix`` + + Args: + - input_data (Callable): A function passed by XGBoost with the same signature as `DMatrix`. + + Returns: + - int: 0 if end of iteration, 1 otherwise. + """ + if self._it == len(self._data_shards): + # return 0 to let XGBoost know this is the end of iteration + return 0 + + # input_data is a function passed in by XGBoost who has the exact same signature of + # ``DMatrix`` + X, y = self._get_shard_by_index(self._it) # self._data_shards[self._it]) + input_data(data=sp.csr_matrix(X), label=y) + self._it += 1 + # Return 1 to let XGBoost know we haven't seen all the files yet. + return 1 + + @TimeableMixin.TimeAs + def reset(self): + """Reset the iterator to its beginning.""" + self._it = 0 + + @TimeableMixin.TimeAs + def collect_in_memory(self) -> tuple[sp.csc_matrix, np.ndarray]: + """Collects data from all shards into memory and returns it. + + This method iterates through all data shards, retrieves the feature data + and labels from each shard, and then concatenates them into a single + sparse matrix and a single array, respectively. + + Returns: + A tuple where the first element is a sparse matrix containing the + feature data, and the second element is a numpy array containing the labels. 
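+
+        Example (sketch of the in-memory path taken when ``model_params.iterator.keep_data_in_memory`` is
+        set; ``cfg`` stands in for a resolved hydra config):
+
+            >>> it = Iterator(cfg, split="train")  # doctest: +SKIP
+            >>> X, y = it.collect_in_memory()  # doctest: +SKIP
+            >>> dtrain = xgb.DMatrix(X, label=y)  # doctest: +SKIP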
+ """ + + X = [] + y = [] + for i in range(len(self._data_shards)): + X_, y_ = self._get_shard_by_index(i) + X.append(X_) + y.append(y_) + + X = sp.vstack(X) + y = np.concatenate(y, axis=0) + return X, y + + +class XGBoostModel(TimeableMixin): + def __init__(self, cfg: DictConfig): + """Initialize the XGBoostClassifier with the provided configuration. + + Args: + - cfg (DictConfig): Configuration dictionary. + """ + + self.cfg = cfg + self.keep_data_in_memory = cfg.model_params.iterator.keep_data_in_memory + + self.itrain = None + self.ituning = None + self.iheld_out = None + + self.dtrain = None + self.dtuning = None + self.dheld_out = None + + self.model = None + + @TimeableMixin.TimeAs + def _train(self): + """Train the model.""" + self.model = xgb.train( + OmegaConf.to_container(self.cfg.model_params.model), + self.dtrain, + num_boost_round=self.cfg.model_params.num_boost_round, + early_stopping_rounds=self.cfg.model_params.early_stopping_rounds, + # nthreads=self.cfg.nthreads, + evals=[(self.dtrain, "train"), (self.dtuning, "tuning")], + ) + + @TimeableMixin.TimeAs + def train(self): + """Train the model.""" + self._build() + self._train() + + @TimeableMixin.TimeAs + def _build(self): + """Build necessary data structures for training.""" + if self.keep_data_in_memory: + self._build_iterators() + self._build_dmatrix_in_memory() + else: + self._build_iterators() + self._build_dmatrix_from_iterators() + + @TimeableMixin.TimeAs + def _build_dmatrix_in_memory(self): + """Build the DMatrix from the data in memory.""" + X_train, y_train = self.itrain.collect_in_memory() + X_tuning, y_tuning = self.ituning.collect_in_memory() + X_held_out, y_held_out = self.iheld_out.collect_in_memory() + self.dtrain = xgb.DMatrix(X_train, label=y_train) + self.dtuning = xgb.DMatrix(X_tuning, label=y_tuning) + self.dheld_out = xgb.DMatrix(X_held_out, label=y_held_out) + + @TimeableMixin.TimeAs + def _build_dmatrix_from_iterators(self): + """Build the DMatrix from the iterators.""" + self.dtrain = xgb.DMatrix(self.itrain) + self.dtuning = xgb.DMatrix(self.ituning) + self.dheld_out = xgb.DMatrix(self.iheld_out) + + @TimeableMixin.TimeAs + def _build_iterators(self): + """Build the iterators for training, validation, and testing.""" + self.itrain = Iterator(self.cfg, split="train") + self.ituning = Iterator(self.cfg, split="tuning") + self.iheld_out = Iterator(self.cfg, split="held_out") + + @TimeableMixin.TimeAs + def evaluate(self) -> float: + """Evaluate the model on the test set. + + Returns: + - float: Evaluation metric (mae). + """ + y_pred = self.model.predict(self.dheld_out) + y_true = self.dheld_out.get_label() + return roc_auc_score(y_true, y_pred) + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem) +def main(cfg: DictConfig) -> float: + """Optimize the model based on the provided configuration. + + Args: + - cfg (DictConfig): Configuration dictionary. + + Returns: + - float: Evaluation result. 
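+
+    Example (illustrative invocation through the ``meds-tab-xgboost`` console script; the override keys
+    below mirror those used in the integration tests, and the cohort path is a placeholder):
+
+        meds-tab-xgboost MEDS_cohort_dir=/path/to/processed \
+            tabularization.min_code_inclusion_frequency=1 \
+            "tabularization.window_sizes=[30d,365d,full]" \
+            "tabularization.aggs=[static/present,static/first,code/count,value/sum]"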
+ """ + if not cfg.loguru_init: + hydra_loguru_init() + + model = XGBoostModel(cfg) + model.train() + + print( + "Time Profiling for window sizes ", + f"{cfg.tabularization.window_sizes} and min ", + "code frequency of {cfg.tabularization.min_code_inclusion_frequency}:", + ) + print("Train Time: \n", model._profile_durations()) + print("Train Iterator Time: \n", model.itrain._profile_durations()) + print("Tuning Iterator Time: \n", model.ituning._profile_durations()) + print("Held Out Iterator Time: \n", model.iheld_out._profile_durations()) + + # save model + save_dir = Path(cfg.output_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving the model to directory: {save_dir}") + model.model.save_model(save_dir / "model.json") + auc = model.evaluate() + logger.info(f"AUC: {auc}") + return auc + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_tabular_automl/scripts/sweep_xgboost.py b/src/MEDS_tabular_automl/scripts/sweep_xgboost.py new file mode 100644 index 0000000..3e019b6 --- /dev/null +++ b/src/MEDS_tabular_automl/scripts/sweep_xgboost.py @@ -0,0 +1,85 @@ +import warnings +from copy import deepcopy +from importlib.resources import files +from itertools import combinations + +import hydra +import optuna +from loguru import logger +from omegaconf import DictConfig, OmegaConf, open_dict + +from MEDS_tabular_automl.scripts import launch_xgboost + +warnings.filterwarnings("ignore", category=UserWarning) + +config_yaml = files("MEDS_tabular_automl").joinpath("configs/launch_xgboost.yaml") +if not config_yaml.is_file(): + raise FileNotFoundError("Core configuration not successfully installed!") + + +def generate_permutations(list_of_options): + """Generate all possible permutations of a list of options. + + Args: + - list_of_options (list): List of options. 
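+
+    Example (illustrative):
+        >>> generate_permutations(["30d", "365d", "full"])  # doctest: +SKIP
+        [('30d',), ('365d',), ('full',), ('30d', '365d'), ('30d', 'full'), ('365d', 'full'),
+         ('30d', '365d', 'full')]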
+ + Returns: + - list: List of all possible permutations of length > 1 + """ + permutations = [] + for i in range(1, len(list_of_options) + 1): + permutations.extend(list(combinations(list_of_options, r=i))) + return permutations + + +OmegaConf.register_new_resolver("generate_permutations", generate_permutations) + + +def xgboost_singleton(trial: optuna.Trial, config: DictConfig) -> float: + for key, value in config.optuna.params.suggest_categorical.items(): + logger.info(f"Optimizing {key} with {value}") + config.tabularization[key] = trial.suggest_categorical(key, value) + for key, value in config.optuna.params.suggest_float.items(): + with open_dict(config): + config[key] = trial.suggest_float(key, **value) + for key, value in config.optuna.params.suggest_int.items(): + with open_dict(config): + config[key] = trial.suggest_int(key, **value) + return launch_xgboost.main(config) + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem) +def main(cfg: DictConfig) -> None: + study = optuna.create_study( + study_name=cfg.optuna.study_name, + storage=cfg.optuna.storage, + load_if_exists=cfg.optuna.load_if_exists, + direction=cfg.optuna.direction, + sampler=cfg.optuna.sampler, + pruner=cfg.optuna.pruner, + ) + study.optimize( + lambda trial: xgboost_singleton(trial, deepcopy(cfg)), + n_trials=cfg.optuna.n_trials, + n_jobs=cfg.optuna.n_jobs, + show_progress_bar=cfg.optuna.show_progress_bar, + ) + print( + "Number of finished trials: ", + len([t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]), + ) + print( + "Number of pruned trials: ", + len([t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]), + ) + print("Sampler:", study.sampler) + print("Best trial:") + trial = study.best_trial + print(" Value: ", trial.value) + print(" Params: ") + for key, value in trial.params.items(): + print(f" {key}: {value}") + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_tabular_automl/scripts/tabularize_static.py b/src/MEDS_tabular_automl/scripts/tabularize_static.py new file mode 100644 index 0000000..d653ac2 --- /dev/null +++ b/src/MEDS_tabular_automl/scripts/tabularize_static.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +"""Tabularizes static data in MEDS format into tabular representations.""" + +from itertools import product +from pathlib import Path + +import hydra +import numpy as np +import polars as pl + +pl.enable_string_cache() + +from importlib.resources import files + +from omegaconf import DictConfig + +from MEDS_tabular_automl.describe_codes import ( + convert_to_df, + filter_parquet, + filter_to_codes, + get_feature_columns, + get_feature_freqs, +) +from MEDS_tabular_automl.file_name import list_subdir_files +from MEDS_tabular_automl.generate_static_features import get_flat_static_rep +from MEDS_tabular_automl.mapper import wrap as rwlock_wrap +from MEDS_tabular_automl.utils import ( + STATIC_CODE_AGGREGATION, + STATIC_VALUE_AGGREGATION, + get_shard_prefix, + hydra_loguru_init, + load_tqdm, + write_df, +) + +config_yaml = files("MEDS_tabular_automl").joinpath("configs/tabularization.yaml") +if not config_yaml.is_file(): + raise FileNotFoundError("Core configuration not successfully installed!") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem) +def main( + cfg: DictConfig, +): + """Writes a flat (historically summarized) representation of the dataset to disk. 
+
+    This script caches a set of files to disk that are useful for building flat (tabular) representations
+    of the dataset, suitable for, e.g., sklearn-style modeling of downstream tasks. It proceeds in two
+    steps:
+
+    * It first filters the code metadata to the codes allowed by the tabularization config (via
+      ``tabularization.allowed_codes`` and ``tabularization.min_code_inclusion_frequency``) and caches the
+      filtered metadata to ``tabularization.filtered_code_metadata_fp``.
+    * It then iterates through the MEDS shards and, for each shard and static aggregation, writes a sparse
+      ``.npz`` matrix under ``output_dir/<split>/<shard>/none/<agg>.npz``.
+
+    Args:
+        cfg:
+            input_dir: directory of the MEDS format dataset shards that are ingested.
+            output_dir: output directory for the tabularized static data.
+            input_code_metadata_fp: path to the code metadata (feature frequencies) produced by the
+                ``meds-tab-describe`` stage.
+            tabularization.min_code_inclusion_frequency: The base feature inclusion frequency that should be
+                used to dictate what features can be included in the flat representation. If `None`, no
+                frequency-based filtering is applied.
+            tabularization.window_sizes: Beyond writing out a raw, per-event flattened representation, the
+                pipeline can also summarize these flattened representations over the historical windows
+                specified in this argument. These are strings specifying time deltas, using this syntax:
+                `link`_. Each window size is summarized to a separate directory; this particular script only
+                produces the static (``none`` window) representations.
+            tabularization.allowed_codes: A list of codes to include in the flat representation. If `None`,
+                all codes will be included in the flat representation.
+            tabularization.aggs: A list of aggregations to apply to the raw representation. Must have length
+                greater than 0.
+            do_overwrite: If `True`, this function will overwrite the data already stored in the target save
+                directory.
+            seed: The seed to use for random number generation.
+
+    .. 
_link: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.groupby_rolling.html # noqa: E501 + """ + iter_wrapper = load_tqdm(cfg.tqdm) + if not cfg.loguru_init: + hydra_loguru_init() + + # Step 1: Cache the filtered features that will be used in the tabularization process and modeling + # import pdb; pdb.set_trace() + def read_fn(_): + return _ + + def compute_fn(_): + filtered_feature_columns = filter_to_codes( + cfg.tabularization.allowed_codes, + cfg.tabularization.min_code_inclusion_frequency, + cfg.input_code_metadata_fp, + ) + feature_freqs = get_feature_freqs(cfg.input_code_metadata_fp) + filtered_feeature_freqs = { + code: count for code, count in feature_freqs.items() if code in filtered_feature_columns + } + return convert_to_df(filtered_feeature_freqs) + + def write_fn(data, out_fp): + data.write_parquet(out_fp) + + in_fp = Path(cfg.input_code_metadata_fp) + out_fp = Path(cfg.tabularization.filtered_code_metadata_fp) + rwlock_wrap( + in_fp, + out_fp, + read_fn, + write_fn, + compute_fn, + do_overwrite=cfg.do_overwrite, + do_return=False, + ) + # Step 2: Produce static data representation + meds_shard_fps = list_subdir_files(cfg.input_dir, "parquet") + feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp) + + # shuffle tasks + aggs = cfg.tabularization.aggs + static_aggs = [agg for agg in aggs if agg in [STATIC_CODE_AGGREGATION, STATIC_VALUE_AGGREGATION]] + tabularization_tasks = list(product(meds_shard_fps, static_aggs)) + np.random.shuffle(tabularization_tasks) + + for shard_fp, agg in iter_wrapper(tabularization_tasks): + out_fp = ( + Path(cfg.output_dir) / get_shard_prefix(cfg.input_dir, shard_fp) / "none" / agg + ).with_suffix(".npz") + if out_fp.exists() and not cfg.do_overwrite: + raise FileExistsError(f"do_overwrite is {cfg.do_overwrite} and {out_fp} exists!") + + def read_fn(in_fp): + return filter_parquet(in_fp, cfg.tabularization._resolved_codes) + + def compute_fn(shard_df): + return get_flat_static_rep( + agg=agg, + feature_columns=feature_columns, + shard_df=shard_df, + ) + + def write_fn(data, out_df): + write_df(data, out_df, do_overwrite=cfg.do_overwrite) + + rwlock_wrap( + shard_fp, + out_fp, + read_fn, + write_fn, + compute_fn, + do_overwrite=cfg.do_overwrite, + do_return=False, + ) + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_tabular_automl/scripts/tabularize_time_series.py b/src/MEDS_tabular_automl/scripts/tabularize_time_series.py new file mode 100644 index 0000000..d3a653d --- /dev/null +++ b/src/MEDS_tabular_automl/scripts/tabularize_time_series.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python + +"""Aggregates time-series data for feature columns across different window sizes.""" +import polars as pl + +pl.enable_string_cache() + +from importlib.resources import files +from itertools import product +from pathlib import Path + +import hydra +import numpy as np +from loguru import logger +from omegaconf import DictConfig + +from MEDS_tabular_automl.describe_codes import filter_parquet, get_feature_columns +from MEDS_tabular_automl.file_name import list_subdir_files +from MEDS_tabular_automl.generate_summarized_reps import generate_summary +from MEDS_tabular_automl.generate_ts_features import get_flat_ts_rep +from MEDS_tabular_automl.mapper import wrap as rwlock_wrap +from MEDS_tabular_automl.utils import ( + STATIC_CODE_AGGREGATION, + STATIC_VALUE_AGGREGATION, + get_shard_prefix, + hydra_loguru_init, + load_tqdm, + write_df, +) + +config_yaml = 
files("MEDS_tabular_automl").joinpath("configs/tabularization.yaml") +if not config_yaml.is_file(): + raise FileNotFoundError("Core configuration not successfully installed!") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem) +def main( + cfg: DictConfig, +): + """Processes time-series data by summarizing it across different windows, creating a flat, summarized + representation of the data for analysis. + + This function orchestrates the data processing pipeline for summarizing time-series data. It loads + data from the tabularize_ts stage, iterates through the pivoted wide dataframes for each split and + shards and then applies a range aggregations across different window sizes defined in the config + The summarized data is then written to disk in a structured directory format. + + Args: + cfg: A configuration dictionary derived from Hydra, containing parameters such as the input data + directory, output directory, and specifics regarding the summarization process (like window + sizes and aggregation functions). + + Workflow: + 1. Set up the environment based on configuration settings. + 2. Load and categorize time-series file paths by their data splits. + 3. Pair code and value files for each split. + 4. For each pair of files in each split: + - Load the dataframes in a lazy manner. + - Summarize the dataframes based on predefined window sizes and aggregation methods. + - Write the summarized dataframe to disk. + + Raises: + FileNotFoundError: If specified directories or files in the configuration are not found. + ValueError: If required columns like 'code' or 'value' are missing in the data files. + """ + iter_wrapper = load_tqdm(cfg.tqdm) + if not cfg.loguru_init: + hydra_loguru_init() + # Produce ts representation + meds_shard_fps = list_subdir_files(cfg.input_dir, "parquet") + feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp) + + # shuffle tasks + aggs = [ + agg + for agg in cfg.tabularization.aggs + if agg not in [STATIC_CODE_AGGREGATION, STATIC_VALUE_AGGREGATION] + ] + tabularization_tasks = list(product(meds_shard_fps, cfg.tabularization.window_sizes, aggs)) + np.random.shuffle(tabularization_tasks) + + # iterate through them + for shard_fp, window_size, agg in iter_wrapper(tabularization_tasks): + out_fp = ( + Path(cfg.output_dir) / get_shard_prefix(cfg.input_dir, shard_fp) / window_size / agg + ).with_suffix(".npz") + + def read_fn(in_fp): + return filter_parquet(in_fp, cfg.tabularization._resolved_codes) + + def compute_fn(shard_df): + # Load Sparse DataFrame + index_df, sparse_matrix = get_flat_ts_rep(agg, feature_columns, shard_df) + + # Summarize data -- applying aggregations on a specific window size + aggregation combination + summary_df = generate_summary( + feature_columns, + index_df, + sparse_matrix, + window_size, + agg, + ) + assert summary_df.shape[1] > 0, "No data found in the summarized dataframe" + + logger.info("Writing pivot file") + return summary_df + + def write_fn(out_matrix, out_fp): + coo_matrix = out_matrix.tocoo() + write_df(coo_matrix, out_fp, do_overwrite=cfg.do_overwrite) + + rwlock_wrap( + shard_fp, + out_fp, + read_fn, + write_fn, + compute_fn, + do_overwrite=cfg.do_overwrite, + do_return=False, + ) + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_tabular_automl/utils.py b/src/MEDS_tabular_automl/utils.py new file mode 100644 index 0000000..dc0ebed --- /dev/null +++ b/src/MEDS_tabular_automl/utils.py @@ -0,0 +1,444 @@ +"""The base class 
for core dataset processing logic. + +Attributes: + INPUT_DF_T: This defines the type of the allowable input dataframes -- e.g., databases, filepaths, + dataframes, etc. + DF_T: This defines the type of internal dataframes -- e.g. polars DataFrames. +""" +import os +from collections.abc import Mapping +from pathlib import Path + +import hydra +import numpy as np +import polars as pl +import polars.selectors as cs +from loguru import logger +from omegaconf import DictConfig, OmegaConf +from scipy.sparse import coo_array + +DF_T = pl.LazyFrame +WRITE_USE_PYARROW = True +ROW_IDX_NAME = "__row_idx" + +STATIC_CODE_AGGREGATION = "static/present" +STATIC_VALUE_AGGREGATION = "static/first" + +CODE_AGGREGATIONS = [ + "code/count", +] + +VALUE_AGGREGATIONS = [ + "value/count", + "value/has_values_count", + "value/sum", + "value/sum_sqd", + "value/min", + "value/max", +] + + +def hydra_loguru_init() -> None: + """Adds loguru output to the logs that hydra scrapes. + + Must be called from a hydra main! + """ + hydra_path = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + logger.add(os.path.join(hydra_path, "main.log")) + + +def load_tqdm(use_tqdm): + if use_tqdm: + from tqdm import tqdm + + return tqdm + else: + + def noop(x, **kwargs): + return x + + return noop + + +def parse_static_feature_column(c: str) -> tuple[str, str, str, str]: + parts = c.split("/") + if len(parts) < 3: + raise ValueError(f"Column {c} is not a valid flat feature column!") + return ("/".join(parts[:-2]), parts[-2], parts[-1]) + + +def array_to_sparse_matrix(array: np.ndarray, shape: tuple[int, int]): + assert array.shape[0] == 3 + data, row, col = array + return coo_array((data, (row, col)), shape=shape) + + +def get_min_dtype(array): + return np.result_type(np.min_scalar_type(array.min()), array.max()) + + +def sparse_matrix_to_array(coo_matrix: coo_array): + data, row, col = coo_matrix.data, coo_matrix.row, coo_matrix.col + # Remove invalid indices + valid_indices = (data == 0) | np.isnan(data) + data = data[~valid_indices] + row = row[~valid_indices] + col = col[~valid_indices] + # reduce dtypes + if len(data): + data = data.astype(get_min_dtype(data)) + row = row.astype(get_min_dtype(row)) + col = col.astype(get_min_dtype(col)) + + return np.array([data, row, col]), coo_matrix.shape + + +def store_matrix(coo_matrix: coo_array, fp_path: Path): + array, shape = sparse_matrix_to_array(coo_matrix) + np.savez(fp_path, array=array, shape=shape) + + +def load_matrix(fp_path: Path): + npzfile = np.load(fp_path) + array, shape = npzfile["array"], npzfile["shape"] + return array_to_sparse_matrix(array, shape) + + +def write_df(df: coo_array, fp: Path, **kwargs): + """Write shard to disk.""" + do_overwrite = kwargs.get("do_overwrite", False) + + if not do_overwrite and fp.is_file(): + raise FileExistsError(f"{fp} exists and do_overwrite is {do_overwrite}!") + + fp.parent.mkdir(exist_ok=True, parents=True) + + if isinstance(df, pl.LazyFrame): + df.collect().write_parquet(fp, use_pyarrow=WRITE_USE_PYARROW) + elif isinstance(df, pl.DataFrame): + df.write_parquet(fp, use_pyarrow=WRITE_USE_PYARROW) + elif isinstance(df, coo_array): + store_matrix(df, fp) + else: + raise TypeError(f"Unsupported type for df: {type(df)}") + + +def get_static_col_dtype(col: str) -> pl.DataType: + """Gets the appropriate minimal dtype for the given flat representation column string.""" + + code, code_type, agg = parse_static_feature_column(col) + + match agg: + case "sum" | "sum_sqd" | "min" | "max" | "value" | "first": + return pl.Float32 + case 
"present": + return pl.Boolean + case "count" | "has_values_count": + return pl.UInt32 + case _: + raise ValueError(f"Column name {col} malformed!") + + +def add_static_missing_cols( + flat_df: DF_T, feature_columns: list[str], set_count_0_to_null: bool = False +) -> DF_T: + """Normalizes columns in a DataFrame so all expected columns are present and appropriately typed. + + Parameters: + - flat_df (DF_T): The DataFrame to be normalized. + - feature_columns (list[str]): A list of feature column names that should exist in the DataFrame. + - set_count_0_to_null (bool): A flag indicating whether counts of zero should be converted to nulls. + + Returns: + - DF_T: The normalized DataFrame with all columns set to the correct type and zero-counts handled + if specified. + + This function ensures that all necessary columns are added and typed correctly within + a DataFrame, potentially modifying zero counts to nulls based on the configuration. + """ + cols_to_add = set(feature_columns) - set(flat_df.columns) + cols_to_retype = set(feature_columns).intersection(set(flat_df.columns)) + + cols_to_add = [(c, get_static_col_dtype(c)) for c in cols_to_add] + cols_to_retype = [(c, get_static_col_dtype(c)) for c in cols_to_retype] + + if "timestamp" in flat_df.columns: + key_cols = ["patient_id", "timestamp"] + else: + key_cols = ["patient_id"] + + flat_df = flat_df.with_columns( + *[pl.lit(None, dtype=dt).alias(c) for c, dt in cols_to_add], + *[pl.col(c).cast(dt).alias(c) for c, dt in cols_to_retype], + ).select(*key_cols, *feature_columns) + + if not set_count_0_to_null: + return flat_df + + flat_df = flat_df.collect() + + flat_df = flat_df.with_columns( + pl.when(cs.ends_with("count") != 0).then(cs.ends_with("count")).keep_name() + ).lazy() + return flat_df + + +def get_static_feature_cols(shard_df) -> list[str]: + """Generates a list of feature column names from the data within each shard based on specified + configurations. + + Parameters: + - cfg (dict): Configuration dictionary specifying how features should be evaluated and aggregated. + - split_to_shard_df (dict): A dictionary of DataFrames, divided by data split (e.g., 'train', 'test'). + + Returns: + - tuple[list[str], dict]: A tuple containing a list of feature columns and a dictionary of code properties + identified during the evaluation. + + This function evaluates the properties of codes within training data and applies configured + aggregations to generate a comprehensive list of feature columns for modeling purposes. + Examples: + >>> import polars as pl + >>> data = {'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'], + ... 'timestamp': [ + ... None, '2021-01-01', '2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04', None + ... ], + ... 'numerical_value': [1, None, 2, 2, None, None, 3]} + >>> df = pl.DataFrame(data).lazy() + >>> get_static_feature_cols(df) + ['A/static/first', 'A/static/present', 'C/static/first', 'C/static/present'] + """ + feature_columns = [] + static_df = shard_df.filter(pl.col("timestamp").is_null()) + for code in static_df.select(pl.col("code").unique()).collect().to_series(): + static_aggregations = [f"{code}/static/present", f"{code}/static/first"] + feature_columns.extend(static_aggregations) + return sorted(feature_columns) + + +def get_ts_feature_cols(shard_df: DF_T) -> list[str]: + """Generates a list of feature column names from the data within each shard based on specified + configurations. + + Parameters: + - cfg (dict): Configuration dictionary specifying how features should be evaluated and aggregated. 
+ - split_to_shard_df (dict): A dictionary of DataFrames, divided by data split (e.g., 'train', 'test'). + + Returns: + - tuple[list[str], dict]: A tuple containing a list of feature columns and a dictionary of code properties + identified during the evaluation. + + This function evaluates the properties of codes within training data and applies configured + aggregations to generate a comprehensive list of feature columns for modeling purposes. + Examples: + >>> import polars as pl + >>> data = {'code': ['A', 'A', 'B', 'B', 'C', 'C', 'C'], + ... 'timestamp': [None, '2021-01-01', None, None, '2021-01-03', '2021-01-04', None], + ... 'numerical_value': [1, None, 2, 2, None, None, 3]} + >>> df = pl.DataFrame(data).lazy() + >>> aggs = ['value/sum', 'code/count'] + >>> get_ts_feature_cols(aggs, df) + ['A/code', 'A/value', 'C/code', 'C/value'] + """ + ts_df = shard_df.filter(pl.col("timestamp").is_not_null()) + feature_columns = list(ts_df.select(pl.col("code").unique()).collect().to_series()) + feature_columns = [f"{code}/code" for code in feature_columns] + [ + f"{code}/value" for code in feature_columns + ] + return sorted(feature_columns) + + +def get_prediction_ts_cols( + aggregations: list[str], ts_feature_cols: DF_T, window_sizes: list[str] | None = None +) -> list[str]: + """Generates a list of feature column names that will be used for downstream task + Examples: + >>> feature_cols = ['A/code', 'A/value', 'C/code', 'C/value'] + >>> window_sizes = None + >>> aggs = ['value/sum', 'code/count'] + >>> get_prediction_ts_cols(aggs, feature_cols, window_sizes) + error + >>> window_sizes = ["1d"] + >>> get_prediction_ts_cols(aggs, feature_cols, window_sizes) + error + """ + agg_feature_columns = [] + for code in ts_feature_cols: + ts_aggregations = [f"{code}/{agg}" for agg in aggregations] + agg_feature_columns.extend(ts_aggregations) + if window_sizes: + ts_aggregations = [f"{window_size}/{code}" for window_size in window_sizes] + return sorted(ts_aggregations) + + +def get_flat_rep_feature_cols(cfg: DictConfig, shard_df: DF_T) -> list[str]: + """Generates a list of feature column names from the data within each shard based on specified + configurations. + + Parameters: + - cfg (dict): Configuration dictionary specifying how features should be evaluated and aggregated. + - shard_df (DF_T): MEDS format dataframe shard. + + Returns: + - list[str]: list of all feature columns. + + This function evaluates the properties of codes within training data and applies configured + aggregations to generate a comprehensive list of feature columns for modeling purposes. + Example: + >>> data = {'code': ['A', 'A', 'B', 'B'], + ... 'timestamp': [None, '2021-01-01', None, None], + ... 'numerical_value': [1, None, 2, 2]} + >>> df = pl.DataFrame(data).lazy() + >>> aggs = ['value/sum', 'code/count'] + >>> cfg = DictConfig({'aggs': aggs}) + >>> get_flat_rep_feature_cols(cfg, df) # doctest: +NORMALIZE_WHITESPACE + ['A/static/first', 'A/static/present', 'B/static/first', 'B/static/present', 'A/code/count', + 'A/value/sum'] + """ + static_feature_columns = get_static_feature_cols(shard_df) + ts_feature_columns = get_ts_feature_cols(cfg.aggs, shard_df) + return static_feature_columns + ts_feature_columns + + +def load_meds_data(MEDS_cohort_dir: str, load_data: bool = True) -> Mapping[str, pl.DataFrame]: + """Loads the MEDS dataset from disk. + + Args: + MEDS_cohort_dir: The directory containing the MEDS datasets split by subfolders. + We expect `train` to be a split so `MEDS_cohort_dir/train` should exist. 
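+        load_data: If `False`, the returned mapping contains the list of parquet file paths for each split
+            rather than loaded dataframes.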
+ + Returns: + Mapping[str, pl.DataFrame]: Mapping from split name to a polars DataFrame containing the MEDS dataset. + + Example: + >>> import tempfile + >>> from pathlib import Path + >>> MEDS_cohort_dir = Path(tempfile.mkdtemp()) + >>> for split in ["train", "val", "test"]: + ... split_dir = MEDS_cohort_dir / split + ... split_dir.mkdir() + ... pl.DataFrame({"patient_id": [1, 2, 3]}).write_parquet(split_dir / "data.parquet") + >>> split_to_df = load_meds_data(MEDS_cohort_dir) + >>> assert "train" in split_to_df + >>> assert len(split_to_df) == 3 + >>> assert len(split_to_df["train"]) == 1 + >>> assert isinstance(split_to_df["train"][0], pl.LazyFrame) + """ + MEDS_cohort_dir = Path(MEDS_cohort_dir) + meds_fps = list(MEDS_cohort_dir.glob("*/*.parquet")) + splits = {fp.parent.stem for fp in meds_fps} + split_to_fps = {split: [fp for fp in meds_fps if fp.parent.stem == split] for split in splits} + if not load_data: + return split_to_fps + split_to_df = { + split: [pl.scan_parquet(fp) for fp in split_fps] for split, split_fps in split_to_fps.items() + } + return split_to_df + + +def get_events_df(shard_df: pl.DataFrame, feature_columns) -> pl.DataFrame: + """Extracts Events DataFrame with one row per observation (timestamps can be duplicated)""" + # Filter out feature_columns that were not present in the training set + raw_feature_columns = ["/".join(c.split("/")[:-1]) for c in feature_columns] + shard_df = shard_df.filter(pl.col("code").is_in(raw_feature_columns)) + # Drop rows with missing timestamp or code to get events + ts_shard_df = shard_df.drop_nulls(subset=["timestamp", "code"]) + return ts_shard_df + + +def get_unique_time_events_df(events_df: pl.DataFrame): + """Updates Events DataFrame to have unique timestamps and sorted by patient_id and timestamp.""" + assert events_df.select(pl.col("timestamp")).null_count().collect().item() == 0 + # Check events_df is sorted - so it aligns with the ts_matrix we generate later in the pipeline + events_df = ( + events_df.drop_nulls("timestamp") + .select(pl.col(["patient_id", "timestamp"])) + .unique(maintain_order=True) + ) + assert events_df.sort(by=["patient_id", "timestamp"]).collect().equals(events_df.collect()) + return events_df + + +def get_feature_names(agg, feature_columns) -> str: + """Indices of columns in feature_columns list.""" + if agg in [STATIC_CODE_AGGREGATION, STATIC_VALUE_AGGREGATION]: + return [c for c in feature_columns if c.endswith(agg)] + elif agg in CODE_AGGREGATIONS: + return [c for c in feature_columns if c.endswith("/code")] + elif agg in VALUE_AGGREGATIONS: + return [c for c in feature_columns if c.endswith("/value")] + else: + raise ValueError(f"Unknown aggregation type {agg}") + + +def get_feature_indices(agg, feature_columns) -> str: + """Indices of columns in feature_columns list.""" + feature_to_index = {c: i for i, c in enumerate(feature_columns)} + agg_features = get_feature_names(agg, feature_columns) + return [feature_to_index[c] for c in agg_features] + + +def store_config_yaml(config_fp: Path, cfg: DictConfig): + """Stores configuration parameters into a JSON file. + + This function writes a dictionary of parameters, which includes patient partitioning + information and configuration details, to a specified JSON file. + + Args: + - config_fp (Path): The file path for the JSON file where config should be stored. 
+ - cfg (DictConfig): A configuration object containing settings like the number of patients + per sub-shard, minimum code inclusion frequency, and flags for updating or overwriting existing files. + + Behavior: + - If config_fp exists and cfg.do_overwrite is False (without do_update being True), a + FileExistsError is raised to prevent unintentional data loss. + + Raises: + - ValueError: If there are discrepancies between old and new parameters during an update. + - FileExistsError: If the file exists and overwriting is not allowed. + + Example: + >>> cfg = DictConfig({ + ... "n_patients_per_sub_shard": 100, + ... "min_code_inclusion_frequency": 5, + ... "do_overwrite": True, + ... }) + >>> import tempfile + >>> from pathlib import Path + >>> with tempfile.NamedTemporaryFile() as temp_f: + ... config_fp = Path(temp_f.name) + ... store_config_yaml(config_fp, cfg) + ... assert config_fp.exists() + ... store_config_yaml(config_fp, cfg) + ... cfg.do_overwrite = False + ... try: + ... store_config_yaml(config_fp, cfg) + ... except FileExistsError as e: + ... print("FileExistsError Error Triggered") + FileExistsError Error Triggered + """ + OmegaConf.save(cfg, config_fp) + + +def get_shard_prefix(base_path: Path, fp: Path) -> str: + """Extracts the shard prefix from a file path by removing the raw_cohort_dir. + + Args: + base_path: The base path to remove. + fp: The file path to extract the shard prefix from. + + Returns: + The shard prefix (the file path relative to the base path with the suffix removed). + + Examples: + >>> get_shard_prefix(Path("/a/b/c"), Path("/a/b/c/d.parquet")) + 'd' + >>> get_shard_prefix(Path("/a/b/c"), Path("/a/b/c/d/e.csv.gz")) + 'd/e' + """ + + relative_path = fp.relative_to(base_path) + relative_parent = relative_path.parent + file_name = relative_path.name.split(".")[0] + + return str(relative_parent / file_name) diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..0a751b6 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,288 @@ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +import json +import subprocess +import tempfile +from io import StringIO +from pathlib import Path + +import polars as pl +from hydra import compose, initialize +from test_tabularize import ( + CODE_COLS, + EXPECTED_STATIC_FILES, + MEDS_OUTPUTS, + SPLITS_JSON, + STATIC_FIRST_COLS, + STATIC_PRESENT_COLS, + SUMMARIZE_EXPECTED_FILES, + VALUE_COLS, +) + +from MEDS_tabular_automl.describe_codes import get_feature_columns +from MEDS_tabular_automl.file_name import list_subdir_files +from MEDS_tabular_automl.utils import ( + VALUE_AGGREGATIONS, + get_events_df, + get_feature_names, + get_shard_prefix, + get_unique_time_events_df, + load_matrix, +) + + +def run_command(script: str, args: list[str], hydra_kwargs: dict[str, str], test_name: str): + command_parts = [script] + args + [f"{k}={v}" for k, v in hydra_kwargs.items()] + command_out = subprocess.run(" ".join(command_parts), shell=True, capture_output=True) + stderr = command_out.stderr.decode() + stdout = command_out.stdout.decode() + if command_out.returncode != 0: + raise AssertionError(f"{test_name} failed!\nstdout:\n{stdout}\nstderr:\n{stderr}") + return stderr, stdout + + +def test_integration(): + # Step 0: Setup Environment + with tempfile.TemporaryDirectory() as d: + MEDS_cohort_dir = Path(d) / "processed" + + describe_codes_config = { + "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), + "do_overwrite": False, + "seed": 1, + 
"hydra.verbose": True, + "tqdm": False, + "loguru_init": True, + } + + with initialize( + version_base=None, config_path="../src/MEDS_tabular_automl/configs/" + ): # path to config.yaml + overrides = [f"{k}={v}" for k, v in describe_codes_config.items()] + cfg = compose(config_name="describe_codes", overrides=overrides) # config.yaml + + # Create the directories + (MEDS_cohort_dir / "final_cohort").mkdir(parents=True, exist_ok=True) + + # Store MEDS outputs + for split, data in MEDS_OUTPUTS.items(): + file_path = MEDS_cohort_dir / "final_cohort" / f"{split}.parquet" + file_path.parent.mkdir(exist_ok=True) + df = pl.read_csv(StringIO(data)) + df.with_columns(pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet( + file_path + ) + + # Check the files are not empty + meds_files = list_subdir_files(Path(cfg.input_dir), "parquet") + assert ( + len(list_subdir_files(Path(cfg.input_dir).parent, "parquet")) == 4 + ), "MEDS train split Data Files Should be 4!" + for f in meds_files: + assert pl.read_parquet(f).shape[0] > 0, "MEDS Data Tabular Dataframe Should not be Empty!" + split_json = json.load(StringIO(SPLITS_JSON)) + splits_fp = MEDS_cohort_dir / "splits.json" + json.dump(split_json, splits_fp.open("w")) + + # Step 1: Run the describe_codes script + stderr, stdout = run_command( + "meds-tab-describe", + [], + describe_codes_config, + "describe_codes", + ) + assert (Path(cfg.output_dir) / "config.yaml").is_file() + assert Path(cfg.output_filepath).is_file() + + feature_columns = get_feature_columns(cfg.output_filepath) + assert get_feature_names("code/count", feature_columns) == sorted(CODE_COLS) + assert get_feature_names("static/present", feature_columns) == sorted(STATIC_PRESENT_COLS) + assert get_feature_names("static/first", feature_columns) == sorted(STATIC_FIRST_COLS) + for value_agg in VALUE_AGGREGATIONS: + assert get_feature_names(value_agg, feature_columns) == sorted(VALUE_COLS) + + # Step 2: Run the static data tabularization script + tabularize_config = { + "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), + "do_overwrite": False, + "seed": 1, + "hydra.verbose": True, + "tqdm": False, + "loguru_init": True, + "tabularization.min_code_inclusion_frequency": 1, + "tabularization.aggs": "[static/present,static/first,code/count,value/sum]", + "tabularization.window_sizes": "[30d,365d,full]", + } + stderr, stdout = run_command( + "meds-tab-tabularize-static", + [], + tabularize_config, + "tabularization", + ) + with initialize( + version_base=None, config_path="../src/MEDS_tabular_automl/configs/" + ): # path to config.yaml + overrides = [f"{k}={v}" for k, v in tabularize_config.items()] + cfg = compose(config_name="tabularization", overrides=overrides) # config.yaml + + output_files = list(Path(cfg.output_dir).glob("**/static/**/*.npz")) + actual_files = [get_shard_prefix(Path(cfg.output_dir), each) + ".npz" for each in output_files] + assert set(actual_files) == set(EXPECTED_STATIC_FILES) + # Check the files are not empty + for f in output_files: + static_matrix = load_matrix(f) + assert static_matrix.shape[0] > 0, "Static Data Tabular Dataframe Should not be Empty!" + expected_num_cols = len(get_feature_names(f"static/{f.stem}", feature_columns)) + assert static_matrix.shape[1] == expected_num_cols, ( + f"Static Data Tabular Dataframe Should have {expected_num_cols}" + f"Columns but has {static_matrix.shape[1]}!" 
+ ) + split = f.parts[-5] + shard_num = f.parts[-4] + med_shard_fp = (Path(cfg.input_dir) / split / shard_num).with_suffix(".parquet") + expected_num_rows = ( + get_unique_time_events_df(get_events_df(pl.scan_parquet(med_shard_fp), feature_columns)) + .collect() + .shape[0] + ) + assert static_matrix.shape[0] == expected_num_rows, ( + f"Static Data matrix Should have {expected_num_rows}" + f" rows but has {static_matrix.shape[0]}!" + ) + allowed_codes = cfg.tabularization._resolved_codes + num_allowed_codes = len(allowed_codes) + feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp) + assert num_allowed_codes == len( + feature_columns + ), f"Should have {len(feature_columns)} codes but has {num_allowed_codes}" + + # Step 3: Run the time series tabularization script + tabularize_config = { + "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), + "do_overwrite": False, + "seed": 1, + "hydra.verbose": True, + "tqdm": False, + "loguru_init": True, + "tabularization.min_code_inclusion_frequency": 1, + "tabularization.aggs": "[static/present,static/first,code/count,value/sum]", + "tabularization.window_sizes": "[30d,365d,full]", + } + + stderr, stdout = run_command( + "meds-tab-tabularize-time-series", + ["--multirun", 'worker="range(0,1)"', "hydra/launcher=joblib"], + tabularize_config, + "tabularization", + ) + + # confirm summary files exist: + output_files = list_subdir_files(cfg.output_dir, "npz") + actual_files = [ + get_shard_prefix(Path(cfg.output_dir), each) + ".npz" + for each in output_files + if "none/static" not in str(each) + ] + assert set(actual_files) == set(SUMMARIZE_EXPECTED_FILES) + for f in output_files: + ts_matrix = load_matrix(f) + assert ts_matrix.shape[0] > 0, "Time-Series Tabular Dataframe Should not be Empty!" + expected_num_cols = len(get_feature_names(f"{f.parent.stem}/{f.stem}", feature_columns)) + assert ts_matrix.shape[1] == expected_num_cols, ( + f"Time-Series Tabular Dataframe Should have {expected_num_cols}" + f"Columns but has {ts_matrix.shape[1]}!" + ) + split = f.parts[-5] + shard_num = f.parts[-4] + med_shard_fp = (Path(cfg.input_dir) / split / shard_num).with_suffix(".parquet") + expected_num_rows = ( + get_unique_time_events_df(get_events_df(pl.scan_parquet(med_shard_fp), feature_columns)) + .collect() + .shape[0] + ) + assert ts_matrix.shape[0] == expected_num_rows, ( + f"Time-Series Data matrix Should have {expected_num_rows}" + f" rows but has {ts_matrix.shape[0]}!" 
+ ) + + # Step 4: Run the task_specific_caching script + cache_config = { + "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), + "do_overwrite": False, + "seed": 1, + "hydra.verbose": True, + "tqdm": False, + "loguru_init": True, + "tabularization.min_code_inclusion_frequency": 1, + "tabularization.aggs": "[static/present,static/first,code/count,value/sum]", + "tabularization.window_sizes": "[30d,365d,full]", + } + with initialize( + version_base=None, config_path="../src/MEDS_tabular_automl/configs/" + ): # path to config.yaml + overrides = [f"{k}={v}" for k, v in cache_config.items()] + cfg = compose(config_name="task_specific_caching", overrides=overrides) # config.yaml + # Create fake labels + for f in list_subdir_files(Path(cfg.MEDS_cohort_dir) / "final_cohort", "parquet"): + df = pl.scan_parquet(f) + df = get_unique_time_events_df(get_events_df(df, feature_columns)).collect() + pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]]) + df = df.with_columns(pl.Series(name="label", values=pseudo_labels)) + df = df.select(pl.col(["patient_id", "timestamp", "label"])) + df = df.with_row_index("event_id") + + split = f.parent.stem + shard_num = f.stem + out_f = Path(cfg.input_label_dir) / Path( + get_shard_prefix(Path(cfg.MEDS_cohort_dir) / "final_cohort", f) + ).with_suffix(".parquet") + out_f.parent.mkdir(parents=True, exist_ok=True) + df.write_parquet(out_f) + + stderr, stdout = run_command( + "meds-tab-cache-task", + [], + cache_config, + "task_specific_caching", + ) + # Check the files are not empty + + # Step 5: Run the xgboost script + + xgboost_config_kwargs = { + "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), + "do_overwrite": False, + "seed": 1, + "hydra.verbose": True, + "tqdm": False, + "loguru_init": True, + "tabularization.min_code_inclusion_frequency": 1, + "tabularization.aggs": "[static/present,static/first,code/count,value/sum]", + "tabularization.window_sizes": "[30d,365d,full]", + } + with initialize( + version_base=None, config_path="../src/MEDS_tabular_automl/configs/" + ): # path to config.yaml + overrides = [f"{k}={v}" for k, v in xgboost_config_kwargs.items()] + cfg = compose(config_name="launch_xgboost", overrides=overrides) # config.yaml + stderr, stdout = run_command( + "meds-tab-xgboost", + [], + xgboost_config_kwargs, + "xgboost", + ) + output_files = list(Path(cfg.output_dir).parent.glob("**/*.json")) + assert len(output_files) == 1 + assert output_files[0].stem == "model" + + stderr, stdout = run_command( + "meds-tab-xgboost-sweep", + [], + xgboost_config_kwargs, + "xgboost-sweep", + ) + output_files = list(Path(cfg.output_dir).parent.glob("**/*.json")) + assert len(output_files) == 2 + assert output_files[0].stem == "model" diff --git a/tests/test_tabularize.py b/tests/test_tabularize.py new file mode 100644 index 0000000..ca67465 --- /dev/null +++ b/tests/test_tabularize.py @@ -0,0 +1,419 @@ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +import json +import os +import tempfile +from io import StringIO +from pathlib import Path + +import polars as pl +from hydra import compose, initialize + +from MEDS_tabular_automl.describe_codes import get_feature_columns +from MEDS_tabular_automl.file_name import list_subdir_files +from MEDS_tabular_automl.scripts import ( + cache_task, + describe_codes, + launch_xgboost, + sweep_xgboost, + tabularize_static, + tabularize_time_series, +) +from MEDS_tabular_automl.utils import ( + VALUE_AGGREGATIONS, + get_events_df, + get_feature_names, + 
+    get_shard_prefix,
+    get_unique_time_events_df,
+    load_matrix,
+)
+
+SPLITS_JSON = """{"train/0": [239684, 1195293], "train/1": [68729, 814703], "tuning/0": [754281], "held_out/0": [1500733]}"""  # noqa: E501
+
+MEDS_TRAIN_0 = """
+patient_id,code,timestamp,numerical_value
+239684,HEIGHT,,175.271115221764
+239684,EYE_COLOR//BROWN,,
+239684,DOB,1980-12-28T00:00:00.000000,
+239684,TEMP,2010-05-11T17:41:51.000000,96.0
+239684,ADMISSION//CARDIAC,2010-05-11T17:41:51.000000,
+239684,HR,2010-05-11T17:41:51.000000,102.6
+239684,TEMP,2010-05-11T17:48:48.000000,96.2
+239684,HR,2010-05-11T17:48:48.000000,105.1
+239684,TEMP,2010-05-11T18:25:35.000000,95.8
+239684,HR,2010-05-11T18:25:35.000000,113.4
+239684,HR,2010-05-11T18:57:18.000000,112.6
+239684,TEMP,2010-05-11T18:57:18.000000,95.5
+239684,DISCHARGE,2010-05-11T19:27:19.000000,
+1195293,HEIGHT,,164.6868838269085
+1195293,EYE_COLOR//BLUE,,
+1195293,DOB,1978-06-20T00:00:00.000000,
+1195293,TEMP,2010-06-20T19:23:52.000000,100.0
+1195293,ADMISSION//CARDIAC,2010-06-20T19:23:52.000000,
+1195293,HR,2010-06-20T19:23:52.000000,109.0
+1195293,TEMP,2010-06-20T19:25:32.000000,100.0
+1195293,HR,2010-06-20T19:25:32.000000,114.1
+1195293,HR,2010-06-20T19:45:19.000000,119.8
+1195293,TEMP,2010-06-20T19:45:19.000000,99.9
+1195293,HR,2010-06-20T20:12:31.000000,112.5
+1195293,TEMP,2010-06-20T20:12:31.000000,99.8
+1195293,HR,2010-06-20T20:24:44.000000,107.7
+1195293,TEMP,2010-06-20T20:24:44.000000,100.0
+1195293,TEMP,2010-06-20T20:41:33.000000,100.4
+1195293,HR,2010-06-20T20:41:33.000000,107.5
+1195293,DISCHARGE,2010-06-20T20:50:04.000000,
+"""
+MEDS_TRAIN_1 = """
+patient_id,code,timestamp,numerical_value
+68729,EYE_COLOR//HAZEL,,
+68729,HEIGHT,,160.3953106166676
+68729,DOB,1978-03-09T00:00:00.000000,
+68729,HR,2010-05-26T02:30:56.000000,86.0
+68729,ADMISSION//PULMONARY,2010-05-26T02:30:56.000000,
+68729,TEMP,2010-05-26T02:30:56.000000,97.8
+68729,DISCHARGE,2010-05-26T04:51:52.000000,
+814703,EYE_COLOR//HAZEL,,
+814703,HEIGHT,,156.48559093209357
+814703,DOB,1976-03-28T00:00:00.000000,
+814703,TEMP,2010-02-05T05:55:39.000000,100.1
+814703,HR,2010-02-05T05:55:39.000000,170.2
+814703,ADMISSION//ORTHOPEDIC,2010-02-05T05:55:39.000000,
+814703,DISCHARGE,2010-02-05T07:02:30.000000,
+"""
+MEDS_HELD_OUT_0 = """
+patient_id,code,timestamp,numerical_value
+1500733,HEIGHT,,158.60131573580904
+1500733,EYE_COLOR//BROWN,,
+1500733,DOB,1986-07-20T00:00:00.000000,
+1500733,TEMP,2010-06-03T14:54:38.000000,100.0
+1500733,HR,2010-06-03T14:54:38.000000,91.4
+1500733,ADMISSION//ORTHOPEDIC,2010-06-03T14:54:38.000000,
+1500733,HR,2010-06-03T15:39:49.000000,84.4
+1500733,TEMP,2010-06-03T15:39:49.000000,100.3
+1500733,HR,2010-06-03T16:20:49.000000,90.1
+1500733,TEMP,2010-06-03T16:20:49.000000,100.1
+1500733,DISCHARGE,2010-06-03T16:44:26.000000,
+"""
+MEDS_TUNING_0 = """
+patient_id,code,timestamp,numerical_value
+754281,EYE_COLOR//BROWN,,
+754281,HEIGHT,,166.22261567137025
+754281,DOB,1988-12-19T00:00:00.000000,
+754281,ADMISSION//PULMONARY,2010-01-03T06:27:59.000000,
+754281,TEMP,2010-01-03T06:27:59.000000,99.8
+754281,HR,2010-01-03T06:27:59.000000,142.0
+754281,DISCHARGE,2010-01-03T08:22:13.000000,
+"""
+
+MEDS_OUTPUTS = {
+    "train/0": MEDS_TRAIN_0,
+    "train/1": MEDS_TRAIN_1,
+    "held_out/0": MEDS_HELD_OUT_0,
+    "tuning/0": MEDS_TUNING_0,
+}
+
+CODE_COLS = [
+    "ADMISSION//CARDIAC/code",
+    "ADMISSION//ORTHOPEDIC/code",
+    "ADMISSION//PULMONARY/code",
+    "DISCHARGE/code",
+    "DOB/code",
+    "HR/code",
+    "TEMP/code",
+]
+VALUE_COLS = ["HR/value", "TEMP/value"]
+STATIC_PRESENT_COLS = [
"EYE_COLOR//BLUE/static/present", + "EYE_COLOR//BROWN/static/present", + "EYE_COLOR//HAZEL/static/present", + "HEIGHT/static/present", +] +STATIC_FIRST_COLS = ["HEIGHT/static/first"] + +EXPECTED_STATIC_FILES = [ + "held_out/0/none/static/first.npz", + "held_out/0/none/static/present.npz", + "train/0/none/static/first.npz", + "train/0/none/static/present.npz", + "train/1/none/static/first.npz", + "train/1/none/static/present.npz", + "tuning/0/none/static/first.npz", + "tuning/0/none/static/present.npz", +] + +SUMMARIZE_EXPECTED_FILES = [ + "train/1/365d/value/sum.npz", + "train/1/365d/code/count.npz", + "train/1/full/value/sum.npz", + "train/1/full/code/count.npz", + "train/1/30d/value/sum.npz", + "train/1/30d/code/count.npz", + "train/0/365d/value/sum.npz", + "train/0/365d/code/count.npz", + "train/0/full/value/sum.npz", + "train/0/full/code/count.npz", + "train/0/30d/value/sum.npz", + "train/0/30d/code/count.npz", + "held_out/0/365d/value/sum.npz", + "held_out/0/365d/code/count.npz", + "held_out/0/full/value/sum.npz", + "held_out/0/full/code/count.npz", + "held_out/0/30d/value/sum.npz", + "held_out/0/30d/code/count.npz", + "tuning/0/365d/value/sum.npz", + "tuning/0/365d/code/count.npz", + "tuning/0/full/value/sum.npz", + "tuning/0/full/code/count.npz", + "tuning/0/30d/value/sum.npz", + "tuning/0/30d/code/count.npz", +] + +MERGE_EXPECTED_FILES = [ + "train/365d/value/sum/0.npz", + "train/365d/value/sum/1.npz", + "train/365d/code/count/0.npz", + "train/365d/code/count/1.npz", + "train/full/value/sum/0.npz", + "train/full/value/sum/1.npz", + "train/full/code/count/0.npz", + "train/full/code/count/1.npz", + "train/30d/value/sum/0.npz", + "train/30d/value/sum/1.npz", + "train/30d/code/count/0.npz", + "train/30d/code/count/1.npz", + "held_out/365d/value/sum/0.npz", + "held_out/365d/code/count/0.npz", + "held_out/full/value/sum/0.npz", + "held_out/full/code/count/0.npz", + "held_out/30d/value/sum/0.npz", + "held_out/30d/code/count/0.npz", + "tuning/365d/value/sum/0.npz", + "tuning/365d/code/count/0.npz", + "tuning/full/value/sum/0.npz", + "tuning/full/code/count/0.npz", + "tuning/30d/value/sum/0.npz", + "tuning/30d/code/count/0.npz", +] + + +def test_tabularize(): + with tempfile.TemporaryDirectory() as d: + MEDS_cohort_dir = Path(d) / "processed" + + describe_codes_config = { + "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), + "do_overwrite": False, + "seed": 1, + "hydra.verbose": True, + "tqdm": False, + "loguru_init": True, + } + + with initialize( + version_base=None, config_path="../src/MEDS_tabular_automl/configs/" + ): # path to config.yaml + overrides = [f"{k}={v}" for k, v in describe_codes_config.items()] + cfg = compose(config_name="describe_codes", overrides=overrides) # config.yaml + + # Create the directories + (MEDS_cohort_dir / "final_cohort").mkdir(parents=True, exist_ok=True) + + # Store MEDS outputs + for split, data in MEDS_OUTPUTS.items(): + file_path = MEDS_cohort_dir / "final_cohort" / f"{split}.parquet" + file_path.parent.mkdir(exist_ok=True) + df = pl.read_csv(StringIO(data)) + df.with_columns(pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S%.f")).write_parquet( + file_path + ) + + # Check the files are not empty + meds_files = list_subdir_files(Path(cfg.input_dir), "parquet") + assert ( + len(list_subdir_files(Path(cfg.input_dir).parent, "parquet")) == 4 + ), "MEDS train split Data Files Should be 4!" + for f in meds_files: + assert pl.read_parquet(f).shape[0] > 0, "MEDS Data Tabular Dataframe Should not be Empty!" 
+        split_json = json.load(StringIO(SPLITS_JSON))
+        splits_fp = MEDS_cohort_dir / "splits.json"
+        json.dump(split_json, splits_fp.open("w"))
+        # Step 1: Describe Codes - compute code frequencies
+        describe_codes.main(cfg)
+
+        assert (Path(cfg.output_dir) / "config.yaml").is_file()
+        assert Path(cfg.output_filepath).is_file()
+
+        feature_columns = get_feature_columns(cfg.output_filepath)
+        assert get_feature_names("code/count", feature_columns) == sorted(CODE_COLS)
+        assert get_feature_names("static/present", feature_columns) == sorted(STATIC_PRESENT_COLS)
+        assert get_feature_names("static/first", feature_columns) == sorted(STATIC_FIRST_COLS)
+        for value_agg in VALUE_AGGREGATIONS:
+            assert get_feature_names(value_agg, feature_columns) == sorted(VALUE_COLS)
+
+        # Step 2: Tabularization
+        tabularize_static_config = {
+            "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
+            "do_overwrite": False,
+            "seed": 1,
+            "hydra.verbose": True,
+            "tqdm": False,
+            "loguru_init": True,
+            "tabularization.min_code_inclusion_frequency": 1,
+            "tabularization.aggs": "[static/present,static/first,code/count,value/sum]",
+            "tabularization.window_sizes": "[30d,365d,full]",
+        }
+
+        with initialize(
+            version_base=None, config_path="../src/MEDS_tabular_automl/configs/"
+        ):  # path to config.yaml
+            overrides = [f"{k}={v}" for k, v in tabularize_static_config.items()]
+            cfg = compose(config_name="tabularization", overrides=overrides)  # config.yaml
+        tabularize_static.main(cfg)
+        output_files = list(Path(cfg.output_dir).glob("**/static/**/*.npz"))
+        actual_files = [get_shard_prefix(Path(cfg.output_dir), each) + ".npz" for each in output_files]
+        assert set(actual_files) == set(EXPECTED_STATIC_FILES)
+        # Check the files are not empty
+        for f in output_files:
+            static_matrix = load_matrix(f)
+            assert static_matrix.shape[0] > 0, "Static Data Tabular Dataframe Should not be Empty!"
+            expected_num_cols = len(get_feature_names(f"static/{f.stem}", feature_columns))
+            assert static_matrix.shape[1] == expected_num_cols, (
+                f"Static Data Tabular Dataframe Should have {expected_num_cols}"
+                f" columns but has {static_matrix.shape[1]}!"
+            )
+            split = f.parts[-5]
+            shard_num = f.parts[-4]
+            med_shard_fp = (Path(cfg.input_dir) / split / shard_num).with_suffix(".parquet")
+            expected_num_rows = (
+                get_unique_time_events_df(get_events_df(pl.scan_parquet(med_shard_fp), feature_columns))
+                .collect()
+                .shape[0]
+            )
+            assert static_matrix.shape[0] == expected_num_rows, (
+                f"Static Data matrix Should have {expected_num_rows}"
+                f" rows but has {static_matrix.shape[0]}!"
+            )
+        allowed_codes = cfg.tabularization._resolved_codes
+        num_allowed_codes = len(allowed_codes)
+        feature_columns = get_feature_columns(cfg.tabularization.filtered_code_metadata_fp)
+        assert num_allowed_codes == len(
+            feature_columns
+        ), f"Should have {len(feature_columns)} codes but has {num_allowed_codes}"
+
+        tabularize_time_series.main(cfg)
+
+        # confirm summary files exist:
+        output_files = list_subdir_files(cfg.output_dir, "npz")
+        actual_files = [
+            get_shard_prefix(Path(cfg.output_dir), each) + ".npz"
+            for each in output_files
+            if "none/static" not in str(each)
+        ]
+        assert set(actual_files) == set(SUMMARIZE_EXPECTED_FILES)
+        for f in output_files:
+            ts_matrix = load_matrix(f)
+            assert ts_matrix.shape[0] > 0, "Time-Series Tabular Dataframe Should not be Empty!"
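+            # Files are named like train/1/365d/value/sum.npz; the checks below verify that each matrix
+            # has one column per feature name for its aggregation and one row per unique
+            # (patient_id, timestamp) event in the corresponding MEDS shard.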
+            expected_num_cols = len(get_feature_names(f"{f.parent.stem}/{f.stem}", feature_columns))
+            assert ts_matrix.shape[1] == expected_num_cols, (
+                f"Time-Series Tabular Dataframe Should have {expected_num_cols}"
+                f" columns but has {ts_matrix.shape[1]}!"
+            )
+            split = f.parts[-5]
+            shard_num = f.parts[-4]
+            med_shard_fp = (Path(cfg.input_dir) / split / shard_num).with_suffix(".parquet")
+            expected_num_rows = (
+                get_unique_time_events_df(get_events_df(pl.scan_parquet(med_shard_fp), feature_columns))
+                .collect()
+                .shape[0]
+            )
+            assert ts_matrix.shape[0] == expected_num_rows, (
+                f"Time-Series Data matrix Should have {expected_num_rows}"
+                f" rows but has {ts_matrix.shape[0]}!"
+            )
+
+        # Step 3: Cache Task data
+        cache_config = {
+            "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
+            "do_overwrite": False,
+            "seed": 1,
+            "hydra.verbose": True,
+            "tqdm": False,
+            "loguru_init": True,
+            "tabularization.min_code_inclusion_frequency": 1,
+            "tabularization.aggs": "[static/present,static/first,code/count,value/sum]",
+            "tabularization.window_sizes": "[30d,365d,full]",
+        }
+
+        with initialize(
+            version_base=None, config_path="../src/MEDS_tabular_automl/configs/"
+        ):  # path to config.yaml
+            overrides = [f"{k}={v}" for k, v in cache_config.items()]
+            cfg = compose(config_name="task_specific_caching", overrides=overrides)  # config.yaml
+
+        # Create fake labels
+        for f in list_subdir_files(Path(cfg.MEDS_cohort_dir) / "final_cohort", "parquet"):
+            df = pl.scan_parquet(f)
+            df = get_unique_time_events_df(get_events_df(df, feature_columns)).collect()
+            pseudo_labels = pl.Series(([0, 1] * df.shape[0])[: df.shape[0]])
+            df = df.with_columns(pl.Series(name="label", values=pseudo_labels))
+            df = df.select(pl.col(["patient_id", "timestamp", "label"]))
+            df = df.with_row_index("event_id")
+
+            split = f.parent.stem
+            shard_num = f.stem
+            out_f = Path(cfg.input_label_dir) / Path(
+                get_shard_prefix(Path(cfg.MEDS_cohort_dir) / "final_cohort", f)
+            ).with_suffix(".parquet")
+            out_f.parent.mkdir(parents=True, exist_ok=True)
+            df.write_parquet(out_f)
+
+        cache_task.main(cfg)
+
+        xgboost_config_kwargs = {
+            "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
+            "do_overwrite": False,
+            "seed": 1,
+            "hydra.verbose": True,
+            "tqdm": False,
+            "loguru_init": True,
+            "tabularization.min_code_inclusion_frequency": 1,
+            "tabularization.aggs": "[static/present,static/first,code/count,value/sum]",
+            "tabularization.window_sizes": "[30d,365d,full]",
+        }
+
+        with initialize(
+            version_base=None, config_path="../src/MEDS_tabular_automl/configs/"
+        ):  # path to config.yaml
+            overrides = [f"{k}={v}" for k, v in xgboost_config_kwargs.items()]
+            cfg = compose(config_name="launch_xgboost", overrides=overrides)  # config.yaml
+
+        launch_xgboost.main(cfg)
+        output_files = list(Path(cfg.output_dir).glob("**/*.json"))
+        assert len(output_files) == 1
+        assert output_files[0] == Path(cfg.output_dir) / "model.json"
+        os.remove(Path(cfg.output_dir) / "model.json")
+
+        xgboost_config_kwargs = {
+            "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
+            "do_overwrite": False,
+            "seed": 1,
+            "hydra.verbose": True,
+            "tqdm": False,
+            "loguru_init": True,
+            "tabularization.min_code_inclusion_frequency": 1,
+            "tabularization.aggs": "[static/present,static/first,code/count,value/sum]",
+            "tabularization.window_sizes": "[30d,365d,full]",
+        }
+
+        with initialize(
+            version_base=None, config_path="../src/MEDS_tabular_automl/configs/"
+        ):  # path to config.yaml
+            overrides = [f"{k}={v}" for k, v in xgboost_config_kwargs.items()]
+            cfg = compose(config_name="launch_xgboost", overrides=overrides)  # config.yaml
+
+        sweep_xgboost.main(cfg)
+        output_files = list(Path(cfg.output_dir).glob("**/*.json"))
+        assert len(output_files) == 1
+        assert output_files[0] == Path(cfg.output_dir) / "model.json"