Skip to content

Commit

Permalink
merging script runs, but the output is 50GB
Browse files Browse the repository at this point in the history
  • Loading branch information
Oufattole committed May 31, 2024
1 parent e8f26eb commit 958906d
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 25 deletions.
26 changes: 26 additions & 0 deletions scripts/e2e.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env bash
# End-to-end driver for the MEDS tabularization pipeline.
#
# Usage: e2e.sh [N_PARALLEL_WORKERS]
#   N_PARALLEL_WORKERS  number of parallel Hydra workers (default: 2)
#
# Stages 1 and 2 (column identification, static tabularization) are currently
# commented out; only the windowed-summary stage runs.

MEDS_DIR=/storage/shared/meds_tabular_ml/ebcl_dataset/processed/final_cohort
OUTPUT_DIR=/storage/shared/meds_tabular_ml/ebcl_dataset/processed/tabularize
# Worker count comes from the first CLI argument, falling back to 2 so the
# previous hard-coded behavior is preserved when no argument is given.
N_PARALLEL_WORKERS="${1:-2}"

# echo "Running identify_columns.py: Caching feature names and frequencies."
# POLARS_MAX_THREADS=32 python scripts/identify_columns.py \
#     MEDS_cohort_dir=$MEDS_DIR \
#     tabularized_data_dir=$OUTPUT_DIR \
#     min_code_inclusion_frequency=1 "window_sizes=[1d, 7d, 30d, 365d, full]" do_overwrite=True

# echo "Running tabularize_static.py: tabularizing static data"
# POLARS_MAX_THREADS=32 python scripts/tabularize_static.py \
#     MEDS_cohort_dir=$MEDS_DIR \
#     tabularized_data_dir=$OUTPUT_DIR \
#     min_code_inclusion_frequency=1 "window_sizes=[1d, 7d, 30d, 365d, full]" do_overwrite=True

echo "Running summarize_over_windows.py with $N_PARALLEL_WORKERS workers in parallel"
# POLARS_MAX_THREADS=1: intra-process polars parallelism is disabled because
# parallelism is delegated to the joblib-launched Hydra workers.
# NOTE: worker="range(0,N)" launches N workers; the previous range(1,N) was
# off by one and launched only N-1 (a single worker for the default N=2).
POLARS_MAX_THREADS=1 python scripts/summarize_over_windows.py \
    --multirun \
    worker="range(0,$N_PARALLEL_WORKERS)" \
    hydra/launcher=joblib \
    MEDS_cohort_dir="$MEDS_DIR" \
    tabularized_data_dir="$OUTPUT_DIR" \
    min_code_inclusion_frequency=1 "window_sizes=[1d, 7d, 30d, 365d, full]" do_overwrite=True
49 changes: 29 additions & 20 deletions scripts/hf_cohort_e2e.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,47 @@

MEDS_DIR=/storage/shared/meds_tabular_ml/ebcl_dataset/processed/final_cohort
OUTPUT_DIR=/storage/shared/meds_tabular_ml/ebcl_dataset/processed/tabularize
N_PARALLEL_WORKERS="$1"
# N_PARALLEL_WORKERS="$1"
WINDOW_SIZES="window_sizes=[1d]"
AGGS="aggs=[code/count,value/sum]"
# WINDOW_SIZES="window_sizes=[1d,7d,30d,365d,full]"
# AGGS="aggs=[static/present,static/first,code/count,value/count,value/sum,value/sum_sqd,value/min,value/max]"

echo "Running identify_columns.py: Caching feature names and frequencies."
rm -rf $OUTPUT_DIR
POLARS_MAX_THREADS=32 python scripts/identify_columns.py \
MEDS_cohort_dir=$MEDS_DIR \
tabularized_data_dir=$OUTPUT_DIR \
min_code_inclusion_frequency=1 "$WINDOW_SIZES" do_overwrite=False "$AGGS"
# echo "Running identify_columns.py: Caching feature names and frequencies."
# rm -rf $OUTPUT_DIR
# POLARS_MAX_THREADS=32 python scripts/identify_columns.py \
# MEDS_cohort_dir=$MEDS_DIR \
# tabularized_data_dir=$OUTPUT_DIR \
# min_code_inclusion_frequency=1 "$WINDOW_SIZES" do_overwrite=False "$AGGS"

echo "Running tabularize_static.py: tabularizing static data"
POLARS_MAX_THREADS=32 python scripts/tabularize_static.py \
MEDS_cohort_dir=$MEDS_DIR \
tabularized_data_dir=$OUTPUT_DIR \
min_code_inclusion_frequency=1 $WINDOW_SIZES do_overwrite=False $AGGS
# echo "Running tabularize_static.py: tabularizing static data"
# POLARS_MAX_THREADS=32 python scripts/tabularize_static.py \
# MEDS_cohort_dir=$MEDS_DIR \
# tabularized_data_dir=$OUTPUT_DIR \
# min_code_inclusion_frequency=1 "$WINDOW_SIZES" do_overwrite=False "$AGGS"

# echo "Running summarize_over_windows.py with $N_PARALLEL_WORKERS workers in parallel"
# # echo "Running summarize_over_windows.py with $N_PARALLEL_WORKERS workers in parallel"
# # POLARS_MAX_THREADS=1 python scripts/summarize_over_windows.py \
# # --multirun \
# # worker="range(0,$N_PARALLEL_WORKERS)" \
# # hydra/launcher=joblib \
# # MEDS_cohort_dir=$MEDS_DIR \
# # tabularized_data_dir=$OUTPUT_DIR \
# # min_code_inclusion_frequency=1 do_overwrite=False \
# # "$WINDOW_SIZES" "$AGGS"

# echo "Running summarize_over_windows.py"
# POLARS_MAX_THREADS=1 python scripts/summarize_over_windows.py \
# --multirun \
# worker="range(0,$N_PARALLEL_WORKERS)" \
# hydra/launcher=joblib \
# MEDS_cohort_dir=$MEDS_DIR \
# tabularized_data_dir=$OUTPUT_DIR \
# min_code_inclusion_frequency=1 do_overwrite=False \
# $WINDOW_SIZES $AGGS
# "$WINDOW_SIZES" "$AGGS"


echo "Running summarize_over_windows.py"
POLARS_MAX_THREADS=1 python scripts/summarize_over_windows.py \
echo "Running tabularize_merge.py"
rm -r "$OUTPUT_DIR/sparse"
POLARS_MAX_THREADS=10 python /home/nassim/projects/MEDS_Tabular_AutoML/scripts/tabularize_merge.py \
MEDS_cohort_dir=$MEDS_DIR \
tabularized_data_dir=$OUTPUT_DIR \
min_code_inclusion_frequency=1 do_overwrite=False \
$WINDOW_SIZES $AGGS
"$WINDOW_SIZES" "$AGGS"
4 changes: 2 additions & 2 deletions scripts/tabularize_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ def tabularize_ts_data(
out_subdir = flat_dir / "sparse"

for shard_fp in iter_wrapper(shard_fps):
split = shard_fp.parent.parent.parent.parent.stem
split = shard_fp.parts[-5]
in_ts_fp = shard_fp
assert in_ts_fp.exists(), f"{in_ts_fp} does not exist!"
in_static_fp = static_dir / split / f"{shard_fp.stem}.parquet"
assert in_static_fp.exists(), f"{in_static_fp} does not exist!"
out_fp = out_subdir / f"{shard_fp.stem}"
out_fp = out_subdir / "/".join(shard_fp.parts[-5:-1]) / f"{shard_fp.stem}"
out_fp.parent.mkdir(parents=True, exist_ok=True)

def read_fn(in_fps):
Expand Down
8 changes: 5 additions & 3 deletions src/MEDS_tabular_automl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections.abc import Mapping
from pathlib import Path

import numpy as np
import pandas as pd
import polars as pl
import polars.selectors as cs
Expand Down Expand Up @@ -60,6 +61,8 @@ def write_df(df: DF_T, fp: Path, **kwargs):
f"Expected DataFrame to have columns ['patient_id', 'timestamp'], got {df.columns[:2]}"
)
df.to_pickle(fp)
elif isinstance(df, np.matrix):
np.save(fp, df)
else:
raise ValueError(f"Unsupported type for df: {type(df)}")

Expand All @@ -77,7 +80,7 @@ def get_static_col_dtype(col: str) -> pl.DataType:
case "count" | "has_values_count":
return pl.UInt32
case _:
raise ValueError(f"Column name {col} malformed! Expected aggregations: 'sum', 'sum_sqd', 'min', 'max', 'value', 'first', 'present', 'count', 'has_values_count'.")
raise ValueError(f"Column name {col} malformed!")


def add_static_missing_cols(
Expand Down Expand Up @@ -346,8 +349,7 @@ def setup_environment(cfg: DictConfig, load_data: bool = True):
stored_config = OmegaConf.create(yaml_config)
logger.info(f"Stored config: {stored_config}")
logger.info(f"Worker config: {cfg}")
assert cfg.keys() == stored_config.keys(), (
f"Keys in stored config do not match current config.")
assert cfg.keys() == stored_config.keys(), "Keys in stored config do not match current config."
for key in cfg.keys():
assert key in stored_config, f"Key {key} not found in stored config."
if key == "worker":
Expand Down

0 comments on commit 958906d

Please sign in to comment.