From 2ec18604a2a339eeba3474daa6cc93a4c9cc1cd9 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 31 May 2024 10:55:32 -0400 Subject: [PATCH] Apply suggestions from code review Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- README.md | 3 +-- scripts/hf_cohort_e2e.sh | 2 +- scripts/hf_cohort_shard.sh | 2 +- src/MEDS_tabular_automl/utils.py | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 033c5c4..5a596d1 100644 --- a/README.md +++ b/README.md @@ -45,8 +45,7 @@ This repository consists of two key pieces: ### Scripts and Examples See `tests/test_tabularize_integration.py` for an example of the end-to-end pipeline being run on synthetic data. This -script is a functional test that is also run with `pytest` to verify correctness of the algorithm. - +script is a functional test that is also run with `pytest` to verify the correctness of the algorithm. #### Core Scripts: 1. `scripts/tabularize/identify_columns.py` loads all training shard to identify which feature columns diff --git a/scripts/hf_cohort_e2e.sh b/scripts/hf_cohort_e2e.sh index 2fbc235..3c39ea5 100644 --- a/scripts/hf_cohort_e2e.sh +++ b/scripts/hf_cohort_e2e.sh @@ -13,7 +13,7 @@ rm -rf $OUTPUT_DIR POLARS_MAX_THREADS=32 python scripts/identify_columns.py \ MEDS_cohort_dir=$MEDS_DIR \ tabularized_data_dir=$OUTPUT_DIR \ - min_code_inclusion_frequency=1 $WINDOW_SIZES do_overwrite=False $AGGS + min_code_inclusion_frequency=1 "$WINDOW_SIZES" do_overwrite=False "$AGGS" echo "Running tabularize_static.py: tabularizing static data" POLARS_MAX_THREADS=32 python scripts/tabularize_static.py \ diff --git a/scripts/hf_cohort_shard.sh b/scripts/hf_cohort_shard.sh index f30878e..351ef3f 100644 --- a/scripts/hf_cohort_shard.sh +++ b/scripts/hf_cohort_shard.sh @@ -1,4 +1,4 @@ - +#!/usr/bin/env bash OUTPUT_DIR=/data/storage/shared/meds_tabular_ml/ebcl_dataset/processed PATIENTS_PER_SHARD="2500" CHUNKSIZE="200_000_000" diff --git a/src/MEDS_tabular_automl/utils.py b/src/MEDS_tabular_automl/utils.py index ef3c016..538db88 100644 --- a/src/MEDS_tabular_automl/utils.py +++ b/src/MEDS_tabular_automl/utils.py @@ -77,7 +77,7 @@ def get_static_col_dtype(col: str) -> pl.DataType: case "count" | "has_values_count": return pl.UInt32 case _: - raise ValueError(f"Column name {col} malformed!") + raise ValueError(f"Column name {col} malformed! Expected aggregations: 'sum', 'sum_sqd', 'min', 'max', 'value', 'first', 'present', 'count', 'has_values_count'.") def add_static_missing_cols( @@ -347,7 +347,7 @@ def setup_environment(cfg: DictConfig, load_data: bool = True): logger.info(f"Stored config: {stored_config}") logger.info(f"Worker config: {cfg}") assert cfg.keys() == stored_config.keys(), ( - f"Keys in stored config do not match current config.")`` + f"Keys in stored config do not match current config.") for key in cfg.keys(): assert key in stored_config, f"Key {key} not found in stored config." if key == "worker":