Converting ESGPT caching to work for MEDS datasets #1
Changes from 61 commits
@@ -1,8 +1,10 @@
# Scalable tabularization and tabular feature usage utilities over generic MEDS datasets

This repository provides utilities and scripts to run limited automatic tabular ML pipelines for generic
MEDS datasets.

#### Q1: What do you mean "tabular pipelines"? Isn't _all_ structured EHR data already tabular?

This is a common misconception. _Tabular_ data refers to data that can be organized in a consistent, logical
set of rows/columns such that the entirety of a "sample" or "instance" for modeling or analysis is contained
in a single row, and the set of columns possibly observed (there can be missingness) is consistent across all

@@ -15,28 +17,62 @@ or future windows in time to produce a single row per patient with a consistent,
(though there may still be missingness).

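To make the distinction concrete, here is a toy sketch (hypothetical data and column names, not this repository's API) contrasting the long, per-measurement MEDS layout with a tabularized one-row-per-patient view:

```python
import polars as pl

# "Long" MEDS-style data: one row per (patient, timestamp, code) measurement.
meds = pl.DataFrame({
    "patient_id": [1, 1, 1, 2],
    "timestamp": ["2020-01-01", "2020-01-01", "2020-06-01", "2020-03-01"],
    "code": ["HR", "DX_FLU", "HR", "HR"],
    "numerical_value": [88.0, None, 72.0, 95.0],
})

# "Tabular" view: one row per patient with a fixed column set (missingness allowed).
tabular = meds.group_by("patient_id").agg(
    pl.col("numerical_value").filter(pl.col("code") == "HR").mean().alias("HR/mean"),
    (pl.col("code") == "DX_FLU").sum().alias("DX_FLU/count"),
)
print(tabular.sort("patient_id"))
# patient 1 -> HR/mean=80.0, DX_FLU/count=1; patient 2 -> HR/mean=95.0, DX_FLU/count=0
```
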
#### Q2: Why not other systems?

- [TemporAI](https://github.com/vanderschaarlab/temporai) is the most natural competitor, and it already
  supports AutoML capabilities. However, TemporAI (as of now) does not support generic MEDS datasets, and it
  is not clear whether its AutoML systems will scale to the size of datasets we need to support. Further
  investigation is needed, and it may be that the best solution here is simply to write a custom data source
  for MEDS data within TemporAI and leverage their tools.

# Installation

Clone this repository and install the requirements by running `pip install .` in the root directory.

# Usage

This repository consists of two key pieces:
1. Construction and efficient loading of tabular (flat, non-longitudinal) summary features describing
   patient records in MEDS over arbitrary time windows (e.g., 1 year, 6 months, etc.), either backwards or
   forwards in time from a given index date. Naturally, only "look-back" windows should be used for
   future-event prediction tasks; however, the capability to summarize "look-ahead" windows is also useful
   for statistically characterizing and describing the differences between patient populations.
2. Running basic AutoML pipelines over these tabular features to predict arbitrary binary classification
   downstream tasks defined over these datasets. The "AutoML" part of this is not particularly advanced --
   what is more advanced is the efficient construction, storage, and loading of tabular features for the
   candidate AutoML models, enabling a far more extensive search over different featurization strategies.

### Scripts and Examples

See `tests/test_tabularize_integration.py` for an example of the end-to-end pipeline being run on synthetic
data. This script is a functional test that is also run with `pytest` to verify the correctness of the
algorithm.

#### Core Scripts:

1. `scripts/identify_columns.py` loads all training shards to identify which feature columns to generate
   tabular data for.
2. `scripts/tabularize_static.py` iterates through shards and generates tabular vectors for each patient,
   with a single row per patient for each shard.
3. `scripts/summarize_over_windows.py` iterates, for each shard, through window sizes and aggregations, and
   horizontally concatenates the outputs to generate the final tabular representations at every event time
   for every patient.
4. `scripts/tabularize_merge` aligns the time-series window aggregations (generated in the previous step)
   with the static tabular vectors and caches them for training.
5. `scripts/hf_cohort/aces_task_extraction.py` generates the task labels and caches them with the `event_id`
   indexes, which align them with the nearest prior event in the tabular data.
6. `scripts/xgboost_sweep.py` tunes XGBoost models, iterating through the labels and corresponding tabular
   data.

We run this on an example dataset using the following bash scripts in sequence:

```bash
bash hf_cohort_shard.sh # processes the dataset into MEDS format
bash hf_cohort_e2e.sh # performs steps 1-4 above
bash hf_cohort/aces_task.sh # generates labels (step 5)
bash xgboost.sh # trains XGBoost (step 6)
```

## Feature Construction, Storage, and Loading

Tabularization of a (raw) MEDS dataset is done by running the `scripts/data/tabularize.py` script. This
script must inherently do a base level of preprocessing over the MEDS data, then will construct a sharded
tabular representation that respects the overall sharding of the raw data. This script uses
[Hydra](https://hydra.cc/)
@@ -45,14 +81,39 @@ to manage configuration, and the configuration file is located at `configs/tabul
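Because Hydra manages the configuration, any key in that configuration file can be overridden on the command line. A hypothetical invocation (placeholder paths; override keys are taken from the configuration documented below):

```bash
python scripts/data/tabularize.py \
    MEDS_cohort_dir=/path/to/meds \
    tabularized_data_dir=/path/to/output \
    "window_sizes=[30d,365d,full]" "aggs=[code/count,value/sum]" \
    do_overwrite=False
```
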
## AutoML Pipelines

# TODOs

1. Leverage the "event bound aggregation" capabilities of [ESGPT Task
   Select](https://github.com/justin13601/ESGPTTaskQuerying/) to construct tabular summary features for
   event-bound historical windows (e.g., until the prior admission, until the last diagnosis of some type,
   etc.).
2. Support more feature aggregation functions.
3. Probably rename this repository, as the focus is really more on the tabularization and feature usage
   utilities than on the AutoML pipelines themselves.
4. Import, rather than reimplement, the mapper utilities from the MEDS preprocessing repository.
5. Investigate the feasibility of using TemporAI for this task.
6. Consider splitting the feature construction and AutoML pipeline parts of this repository into separate
   repositories.

# YAML Configuration File

- `MEDS_cohort_dir`: directory of the MEDS-format dataset that is ingested.
- `tabularized_data_dir`: output directory for the tabularized data.
- `min_code_inclusion_frequency`: the base feature inclusion frequency used to dictate which features can be
  included in the flat representation. It can be a float, in which case it applies across all measurements;
  `None`, in which case no filtering is applied; or a dictionary from measurement type to a float dictating
  a per-measurement-type inclusion cutoff.
- `window_sizes`: beyond writing out a raw, per-event flattened representation, the dataset also has the
  capability to summarize these flattened representations over the historical windows specified in this
  argument. These are strings specifying time deltas (e.g., `1d`, `30d`, `365d`, or `full` for the entire
  history). Each window size will be summarized to a separate directory and will share the same subject
  file split as is used in the raw representation files.
- `codes`: a list of codes to include in the flat representation. If `None`, all codes will be included.
- `aggs`: a list of aggregations to apply to the raw representation. Must have length greater than 0.
- `n_patients_per_sub_shard`: the number of subjects to include in each output file. Lowering this number
  increases the number of files written, making the process of creating and leveraging these files slower
  but more memory efficient.
- `do_overwrite`: if `True`, overwrite any data already stored in the target save directory.
- `seed`: the seed to use for random number generation.
@@ -0,0 +1,60 @@
```yaml
# Raw data
MEDS_cohort_dir: ???
tabularized_data_dir: ???
model_dir: ${tabularized_data_dir}/model

# Pre-processing
min_code_inclusion_frequency: 1
window_sizes: [1d]
codes: null
aggs:
  - "code/count"
  - "value/sum"

dynamic_threshold: 0.01
numerical_value_threshold: 0.1

# Sharding
n_patients_per_sub_shard: null

# Misc
do_overwrite: False
do_update: True
seed: 1
tqdm: True

model:
  booster: gbtree
  device: cpu
  tree_method: hist
  objective: reg:squarederror

iterator:
  keep_data_in_memory: False

# Hydra settings for sweep
defaults:
  - override hydra/sweeper: optuna
  - override hydra/sweeper/sampler: tpe

hydra:
  verbose: False
  sweep:
    dir: ${tabularized_data_dir}/.logs/etl/${now:%Y-%m-%d_%H-%M-%S}
  run:
    dir: ${tabularized_data_dir}/.logs/etl/${now:%Y-%m-%d_%H-%M-%S}

  # Optuna Sweeper
  sweeper:
    sampler:
      seed: 1
    storage: null
    study_name: tabularize_study_${now:%Y-%m-%d_%H-%M-%S}
    direction: minimize
    n_trials: 10

    # Define search space for Optuna
    params:
      window_sizes: choice([30d, 365d, full], [30d, full], [30d])
      # iterator.keep_static_data_in_memory: choice([True], [False])
      # iterator.keep_data_in_memory: choice([True], [False])
```
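Given the Optuna sweeper settings above, a sweep would be launched through Hydra's multirun mode. A sketch, assuming (as the README's step 6 suggests) that `scripts/xgboost_sweep.py` consumes this config:

```bash
# --multirun activates the Optuna sweeper; `params` above defines the search space,
# so each trial samples one of the window_sizes choices.
python scripts/xgboost_sweep.py --multirun \
    MEDS_cohort_dir=/path/to/meds \
    tabularized_data_dir=/path/to/output
```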
@@ -0,0 +1,13 @@
```bash
#!/usr/bin/env bash

MEDS_DIR=/storage/shared/meds_tabular_ml/ebcl_dataset/processed/final_cohort
OUTPUT_DIR=/storage/shared/meds_tabular_ml/ebcl_dataset/processed/tabularize
# N_PARALLEL_WORKERS="$1"
WINDOW_SIZES="window_sizes=[1d]"
AGGS="aggs=[code/count,value/sum]"

python /home/nassim/projects/MEDS_Tabular_AutoML/hf_cohort/aces_task_extraction.py \
    MEDS_cohort_dir=$MEDS_DIR \
    tabularized_data_dir=$OUTPUT_DIR \
    min_code_inclusion_frequency=1 do_overwrite=False \
    "$WINDOW_SIZES" "$AGGS"
```
@@ -0,0 +1,51 @@
""" | ||
Setup Conda environment as described here: https://github.com/justin13601/ACES | ||
""" | ||
from pathlib import Path | ||
|
||
import hydra | ||
import polars as pl | ||
from aces import config, predicates, query | ||
from tqdm import tqdm | ||
|
||
|
||
@hydra.main(version_base=None, config_path="../configs", config_name="tabularize") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tip Codebase Verification The Hydra configuration file
Analysis chainEnsure the Scripts executedThe following scripts were executed for the analysis: Script: #!/bin/bash
# Description: Verify the Hydra configuration paths.
# Test: Check if the specified configuration files exist.
if [ -f "../configs/tabularize.yaml" ]; then
echo "Hydra configuration file exists."
else
echo "Hydra configuration file does not exist."
fi
Length of output: 130 Script: #!/bin/bash
# Description: Search for YAML configuration files in the repository.
# Find all YAML files in the repository.
fd --extension yaml
Length of output: 87 Script: #!/bin/bash
# Description: Check the current working directory when the script is executed.
# Print the current working directory.
pwd
Length of output: 24 |
||
def main(cfg): | ||
# create task configuration object | ||
task_cfg = config.TaskExtractorConfig.load(config_path="hf_cohort/task.yaml") | ||
|
||
# setup directories | ||
med_dir = Path(cfg.tabularized_data_dir) | ||
|
||
# location of MEDS format Data | ||
cohort_dir = med_dir.parent / "final_cohort" | ||
# output directory for tables with event_ids and labels | ||
output_dir = med_dir / "task" | ||
|
||
shard_fps = list(cohort_dir.glob("*/*.parquet")) | ||
|
||
for in_fp in tqdm(shard_fps): | ||
out_fp = output_dir / "/".join(in_fp.parts[-2:]) | ||
out_fp.parent.mkdir(parents=True, exist_ok=True) | ||
# one of the following | ||
predicates_df = predicates.generate_predicates_df(task_cfg, in_fp, "meds") | ||
|
||
# execute query | ||
df_result = query.query(task_cfg, predicates_df) | ||
label_df = ( | ||
df_result.select(pl.col(["subject_id", "trigger", "label"])) | ||
.rename({"trigger": "timestamp", "subject_id": "patient_id"}) | ||
.sort(by=["patient_id", "timestamp"]) | ||
) | ||
data_df = pl.scan_parquet(in_fp) | ||
data_df = data_df.unique(subset=["patient_id", "timestamp"]).sort(by=["patient_id", "timestamp"]) | ||
data_df = data_df.with_row_index("event_id") | ||
data_df = data_df.drop(["code", "numerical_value"]) | ||
output_df = label_df.lazy().join_asof(other=data_df, by="patient_id", on="timestamp") | ||
|
||
# store it | ||
output_df.collect().write_parquet(out_fp) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
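The `join_asof` call above is what ties each task label to the `event_id` of the nearest event at or before the label's timestamp. A self-contained toy illustration of that semantics (hypothetical values):

```python
import polars as pl

labels = pl.DataFrame(
    {"patient_id": [1, 1], "timestamp": [5, 9], "label": [0, 1]}
).lazy()
events = pl.DataFrame(
    {"patient_id": [1, 1, 1], "timestamp": [1, 4, 8], "event_id": [0, 1, 2]}
).lazy()

# The default "backward" strategy matches each label to the last event whose
# timestamp is <= the label's timestamp, within each patient_id group.
joined = labels.join_asof(events, by="patient_id", on="timestamp")
print(joined.collect())
# label at t=5 -> event_id 1 (event at t=4); label at t=9 -> event_id 2 (event at t=8)
```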
@@ -0,0 +1,21 @@
```yaml
# Path to the task configuration file
config_path: task.yaml

# Raw Data
data:
  # Path to the data file or directory
  path: /storage/shared/meds_tabular_ml/ebcl_dataset/processed/final_cohort/train/0.parquet

  # Data standard, one of (csv, meds, esgpt)
  standard: meds

# Output Directory (saves as .parquet file)
output_dir: results/

# Hydra
hydra:
  job:
    name: ACES_${now:%Y-%m-%d_%H-%M-%S}
  run:
    dir: ${ACES_dir}/.logs/${hydra.job.name}
# aces-cli --config-dir='./' --config-name='config.yaml'
```
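The commented-out line at the end of the file records how this config is consumed; per that comment, extraction would be run from the directory containing the file:

```bash
# Command taken verbatim from the comment above; results are written as
# .parquet files under the configured output_dir.
aces-cli --config-dir='./' --config-name='config.yaml'
```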
@@ -0,0 +1,32 @@
```bash
#!/usr/bin/env bash

MEDS_DIR=/storage/shared/meds_tabular_ml/ebcl_dataset/processed
OUTPUT_DIR=/storage/shared/meds_tabular_ml/ebcl_dataset/processed/tabularize
N_PARALLEL_WORKERS="$1" # number of parallel workers, passed as the first argument
WINDOW_SIZES="window_sizes=[1d]"
AGGS="aggs=[code/count,value/sum]"
# WINDOW_SIZES="window_sizes=[1d,7d,30d,365d,full]"
# AGGS="aggs=[static/present,static/first,code/count,value/count,value/sum,value/sum_sqd,value/min,value/max]"

echo "Running identify_columns.py: Caching feature names and frequencies."
rm -rf $OUTPUT_DIR
POLARS_MAX_THREADS=32 python scripts/identify_columns.py \
    MEDS_cohort_dir=$MEDS_DIR \
    tabularized_data_dir=$OUTPUT_DIR \
    min_code_inclusion_frequency=1 "$WINDOW_SIZES" do_overwrite=False "$AGGS"

echo "Running tabularize_static.py: tabularizing static data"
POLARS_MAX_THREADS=32 python scripts/tabularize_static.py \
    MEDS_cohort_dir=$MEDS_DIR \
    tabularized_data_dir=$OUTPUT_DIR \
    min_code_inclusion_frequency=1 "$WINDOW_SIZES" do_overwrite=False "$AGGS"

echo "Running summarize_over_windows.py with $N_PARALLEL_WORKERS workers in parallel"
POLARS_MAX_THREADS=1 python scripts/summarize_over_windows.py \
    --multirun \
    worker="range(0,$N_PARALLEL_WORKERS)" \
    hydra/launcher=joblib \
    MEDS_cohort_dir=$MEDS_DIR \
    tabularized_data_dir=$OUTPUT_DIR \
    min_code_inclusion_frequency=1 do_overwrite=False \
    "$WINDOW_SIZES" "$AGGS"
```