diff --git a/README.md b/README.md index 033c5c4..5a596d1 100644 --- a/README.md +++ b/README.md @@ -45,8 +45,7 @@ This repository consists of two key pieces: ### Scripts and Examples See `tests/test_tabularize_integration.py` for an example of the end-to-end pipeline being run on synthetic data. This -script is a functional test that is also run with `pytest` to verify correctness of the algorithm. - +script is a functional test that is also run with `pytest` to verify the correctness of the algorithm. #### Core Scripts: 1. `scripts/tabularize/identify_columns.py` loads all training shard to identify which feature columns diff --git a/scripts/hf_cohort_e2e.sh b/scripts/hf_cohort_e2e.sh index 2fbc235..3c39ea5 100644 --- a/scripts/hf_cohort_e2e.sh +++ b/scripts/hf_cohort_e2e.sh @@ -13,7 +13,7 @@ rm -rf $OUTPUT_DIR POLARS_MAX_THREADS=32 python scripts/identify_columns.py \ MEDS_cohort_dir=$MEDS_DIR \ tabularized_data_dir=$OUTPUT_DIR \ - min_code_inclusion_frequency=1 $WINDOW_SIZES do_overwrite=False $AGGS + min_code_inclusion_frequency=1 "$WINDOW_SIZES" do_overwrite=False "$AGGS" echo "Running tabularize_static.py: tabularizing static data" POLARS_MAX_THREADS=32 python scripts/tabularize_static.py \ diff --git a/scripts/hf_cohort_shard.sh b/scripts/hf_cohort_shard.sh index f30878e..351ef3f 100644 --- a/scripts/hf_cohort_shard.sh +++ b/scripts/hf_cohort_shard.sh @@ -1,4 +1,4 @@ - +#!/usr/bin/env bash OUTPUT_DIR=/data/storage/shared/meds_tabular_ml/ebcl_dataset/processed PATIENTS_PER_SHARD="2500" CHUNKSIZE="200_000_000" diff --git a/src/MEDS_tabular_automl/utils.py b/src/MEDS_tabular_automl/utils.py index ef3c016..538db88 100644 --- a/src/MEDS_tabular_automl/utils.py +++ b/src/MEDS_tabular_automl/utils.py @@ -77,7 +77,7 @@ def get_static_col_dtype(col: str) -> pl.DataType: case "count" | "has_values_count": return pl.UInt32 case _: - raise ValueError(f"Column name {col} malformed!") + raise ValueError(f"Column name {col} malformed! Expected aggregations: 'sum', 'sum_sqd', 'min', 'max', 'value', 'first', 'present', 'count', 'has_values_count'.") def add_static_missing_cols( @@ -347,7 +347,7 @@ def setup_environment(cfg: DictConfig, load_data: bool = True): logger.info(f"Stored config: {stored_config}") logger.info(f"Worker config: {cfg}") assert cfg.keys() == stored_config.keys(), ( - f"Keys in stored config do not match current config.")`` + f"Keys in stored config do not match current config.") for key in cfg.keys(): assert key in stored_config, f"Key {key} not found in stored config." if key == "worker":