added script for extracting tasks using aces
Oufattole committed Jun 1, 2024
1 parent b6b8d43 commit e8d64fd
Showing 7 changed files with 106 additions and 26 deletions.
13 changes: 13 additions & 0 deletions hf_cohort/aces_task.sh
@@ -0,0 +1,13 @@
#!/usr/bin/env bash

MEDS_DIR=/storage/shared/meds_tabular_ml/ebcl_dataset/processed/final_cohort
OUTPUT_DIR=/storage/shared/meds_tabular_ml/ebcl_dataset/processed/tabularize
# N_PARALLEL_WORKERS="$1"
WINDOW_SIZES="window_sizes=[1d]"
AGGS="aggs=[code/count,value/sum]"

python /home/nassim/projects/MEDS_Tabular_AutoML/hf_cohort/aces_task_extraction.py \
MEDS_cohort_dir=$MEDS_DIR \
tabularized_data_dir=$OUTPUT_DIR \
min_code_inclusion_frequency=1 do_overwrite=False \
"$WINDOW_SIZES" "$AGGS"
51 changes: 51 additions & 0 deletions hf_cohort/aces_task_extraction.py
@@ -0,0 +1,51 @@
"""
Setup Conda environment as described here: https://github.com/justin13601/ACES
"""
from pathlib import Path

import hydra
import polars as pl
from aces import config, predicates, query
from tqdm import tqdm


@hydra.main(version_base=None, config_path="../configs", config_name="tabularize")
def main(cfg):
    # create task configuration object
    task_cfg = config.TaskExtractorConfig.load(config_path="hf_cohort/task.yaml")

    # setup directories
    med_dir = Path(cfg.tabularized_data_dir)

    # location of MEDS-format data
    cohort_dir = med_dir.parent / "final_cohort"
    # output directory for tables with event_ids and labels
    output_dir = med_dir / "task"

    shard_fps = list(cohort_dir.glob("*/*.parquet"))

    for in_fp in tqdm(shard_fps):
        out_fp = output_dir / "/".join(in_fp.parts[-2:])
        out_fp.parent.mkdir(parents=True, exist_ok=True)
        # generate the predicates dataframe from the MEDS shard
        predicates_df = predicates.generate_predicates_df(task_cfg, in_fp, "meds")

        # execute the ACES query to get one label per trigger event
        df_result = query.query(task_cfg, predicates_df)
        label_df = (
            df_result.select(pl.col(["subject_id", "trigger", "label"]))
            .rename({"trigger": "timestamp", "subject_id": "patient_id"})
            .sort(by=["patient_id", "timestamp"])
        )

        # assign a sequential event_id to each unique (patient_id, timestamp) pair in the shard
        data_df = pl.scan_parquet(in_fp)
        data_df = data_df.unique(subset=["patient_id", "timestamp"]).sort(by=["patient_id", "timestamp"])
        data_df = data_df.with_row_index("event_id")
        data_df = data_df.drop(["code", "numerical_value"])

        # attach the event_id of the latest event at or before each label's timestamp
        output_df = label_df.lazy().join_asof(other=data_df, by="patient_id", on="timestamp")

        # store the labels with their aligned event_ids
        output_df.collect().write_parquet(out_fp)


if __name__ == "__main__":
    main()
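
Note (not part of the diff): the as-of join above is what attaches an event_id to every ACES label. A minimal, self-contained polars sketch with synthetic values, illustrating the backward-matching behavior; the column names follow the script, but the data and values are hypothetical:

import polars as pl
from datetime import datetime

# Hypothetical shard after the unique()/sort() calls above: one row per
# (patient_id, timestamp), with a sequential event_id added by with_row_index.
events = pl.DataFrame(
    {
        "patient_id": [1, 1, 2],
        "timestamp": [datetime(2020, 1, 1), datetime(2020, 1, 5), datetime(2020, 3, 2)],
    }
).with_row_index("event_id")

# Hypothetical ACES output: one label per trigger (admission) timestamp.
labels = pl.DataFrame(
    {
        "patient_id": [1, 2],
        "timestamp": [datetime(2020, 1, 5), datetime(2020, 3, 10)],
        "label": [True, False],
    }
)

# join_asof with the default "backward" strategy: each label row picks up the
# event_id of the latest event at or before its timestamp for that patient.
out = labels.sort("patient_id", "timestamp").join_asof(
    events.sort("patient_id", "timestamp"), on="timestamp", by="patient_id"
)
print(out)  # patient 1 -> event_id 1, patient 2 -> event_id 2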
21 changes: 21 additions & 0 deletions hf_cohort/config.yaml
@@ -0,0 +1,21 @@
# Path to the task configuration file
config_path: task.yaml

# Raw Data
data:
  # Path to the data file or directory
  path: /storage/shared/meds_tabular_ml/ebcl_dataset/processed/final_cohort/train/0.parquet

  # Data standard, one of (csv, meds, esgpt)
  standard: meds

# Output Directory (saves as .parquet file)
output_dir: results/

# Hydra
hydra:
  job:
    name: ACES_${now:%Y-%m-%d_%H-%M-%S}
  run:
    dir: ${ACES_dir}/.logs/${hydra.job.name}
# aces-cli --config-dir='./' --config-name='config.yaml'
File renamed without changes.
File renamed without changes.
21 changes: 21 additions & 0 deletions hf_cohort/task.yaml
@@ -0,0 +1,21 @@
# Task: 30-day Readmission Risk Prediction
predicates:
  admission:
    code: ADMIT_DATE
  discharge:
    code: DISCHARGE_DATE

trigger: admission

windows:
  input:
    start: trigger
    end: start -> discharge
    start_inclusive: False
    end_inclusive: True
  target:
    start: input.end
    end: start + 30 days
    start_inclusive: False
    end_inclusive: True
    label: admission
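
Note (not part of the diff): before running hf_cohort/aces_task.sh over all shards, the task definition can be sanity-checked on its own by loading it with ACES, exactly as the extraction script does. A minimal sketch, assuming the ACES environment described in aces_task_extraction.py; printing the parsed config is only for inspection:

from aces import config

# Parse and validate the 30-day readmission task definition; this raises if a
# predicate or window reference in task.yaml is malformed.
task_cfg = config.TaskExtractorConfig.load(config_path="hf_cohort/task.yaml")
print(task_cfg)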
26 changes: 0 additions & 26 deletions scripts/e2e.sh

This file was deleted.
