
Commit

Merge branch 'esgpt_caching' of github.com:mmcdermott/MEDS_Tabular_AutoML into esgpt_caching
Nassim Oufattole committed May 31, 2024
2 parents 77f296f + 2ec1860 commit e8f26eb
Showing 4 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -45,8 +45,7 @@ This repository consists of two key pieces:
### Scripts and Examples

See `tests/test_tabularize_integration.py` for an example of the end-to-end pipeline being run on synthetic data. This
-script is a functional test that is also run with `pytest` to verify correctness of the algorithm.
-
+script is a functional test that is also run with `pytest` to verify the correctness of the algorithm.
#### Core Scripts:

1. `scripts/tabularize/identify_columns.py` loads all training shards to identify which feature columns
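For readers new to the pipeline, the column-identification step named in the README excerpt above can be pictured roughly as follows. This is an illustrative sketch, not the repository's implementation: the function name, the parquet shard layout, and the `code` column are assumptions; only the `min_code_inclusion_frequency` knob is taken from the scripts in this commit.

```python
from pathlib import Path

import polars as pl


def identify_feature_columns(shard_dir: Path, min_code_inclusion_frequency: int = 1) -> list[str]:
    """Collect codes seen at least `min_code_inclusion_frequency` times across training shards."""
    counts: dict[str, int] = {}
    for shard in sorted(shard_dir.glob("*.parquet")):
        # Lazily scan each shard and count how often each code appears.
        shard_counts = (
            pl.scan_parquet(shard)
            .group_by("code")
            .agg(pl.len().alias("n"))
            .collect()
        )
        for code, n in shard_counts.iter_rows():
            counts[code] = counts.get(code, 0) + n
    return sorted(code for code, n in counts.items() if n >= min_code_inclusion_frequency)
```

See `scripts/tabularize/identify_columns.py` in the repository for the authoritative behavior.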
2 changes: 1 addition & 1 deletion scripts/hf_cohort_e2e.sh
@@ -13,7 +13,7 @@ rm -rf $OUTPUT_DIR
POLARS_MAX_THREADS=32 python scripts/identify_columns.py \
MEDS_cohort_dir=$MEDS_DIR \
tabularized_data_dir=$OUTPUT_DIR \
-min_code_inclusion_frequency=1 $WINDOW_SIZES do_overwrite=False $AGGS
+min_code_inclusion_frequency=1 "$WINDOW_SIZES" do_overwrite=False "$AGGS"

echo "Running tabularize_static.py: tabularizing static data"
POLARS_MAX_THREADS=32 python scripts/tabularize_static.py \
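The overrides passed to `identify_columns.py` above (`MEDS_cohort_dir=...`, `tabularized_data_dir=...`, `min_code_inclusion_frequency=1`, plus the `$WINDOW_SIZES` and `$AGGS` variables) use the `key=value` style of Hydra/OmegaConf overrides, which matches the `DictConfig` that `setup_environment` further down receives. Below is a minimal sketch of how such dotlist overrides merge into a config, assuming OmegaConf; the default values and paths are hypothetical, not the repository's schema.

```python
from omegaconf import OmegaConf

# Base config with illustrative defaults (not the repository's actual schema).
base = OmegaConf.create({
    "MEDS_cohort_dir": None,
    "tabularized_data_dir": None,
    "min_code_inclusion_frequency": 10,
    "do_overwrite": False,
})

# Overrides in the same key=value form the shell script passes on the command line.
overrides = OmegaConf.from_dotlist([
    "MEDS_cohort_dir=/path/to/meds_cohort",       # hypothetical path
    "tabularized_data_dir=/path/to/tabularized",  # hypothetical path
    "min_code_inclusion_frequency=1",
    "do_overwrite=False",
])

cfg = OmegaConf.merge(base, overrides)
print(OmegaConf.to_yaml(cfg))
```

The quoting added around `$WINDOW_SIZES` and `$AGGS` in the hunk above controls how the shell expands those values before they reach the Python script, which matters when they contain spaces or bracketed list syntax.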
2 changes: 1 addition & 1 deletion scripts/hf_cohort_shard.sh
@@ -1,4 +1,4 @@
-
+#!/usr/bin/env bash
OUTPUT_DIR=/data/storage/shared/meds_tabular_ml/ebcl_dataset/processed
PATIENTS_PER_SHARD="2500"
CHUNKSIZE="200_000_000"
4 changes: 2 additions & 2 deletions src/MEDS_tabular_automl/utils.py
@@ -77,7 +77,7 @@ def get_static_col_dtype(col: str) -> pl.DataType:
case "count" | "has_values_count":
return pl.UInt32
case _:
-raise ValueError(f"Column name {col} malformed!")
+raise ValueError(f"Column name {col} malformed! Expected aggregations: 'sum', 'sum_sqd', 'min', 'max', 'value', 'first', 'present', 'count', 'has_values_count'.")


def add_static_missing_cols(
@@ -347,7 +347,7 @@ def setup_environment(cfg: DictConfig, load_data: bool = True):
logger.info(f"Stored config: {stored_config}")
logger.info(f"Worker config: {cfg}")
assert cfg.keys() == stored_config.keys(), (
-f"Keys in stored config do not match current config.")``
+f"Keys in stored config do not match current config.")
for key in cfg.keys():
assert key in stored_config, f"Key {key} not found in stored config."
if key == "worker":
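To make the `get_static_col_dtype` hunk above readable in isolation, here is a self-contained sketch of the same suffix-to-dtype dispatch pattern. Only the `count`/`has_values_count` arm and the list of expected aggregations come from the diff; the other arms, their dtypes, and the assumption that the aggregation is the last `/`-separated piece of the column name are illustrative guesses, not the repository's actual mapping.

```python
import polars as pl

VALID_AGGS = (
    "sum", "sum_sqd", "min", "max", "value", "first",
    "present", "count", "has_values_count",
)


def static_col_dtype(col: str) -> pl.DataType:
    """Map a static column's aggregation suffix to a Polars dtype (illustrative sketch)."""
    agg = col.split("/")[-1]  # assumed layout: "<code>/<aggregation>"
    match agg:
        case "sum" | "sum_sqd" | "min" | "max" | "value" | "first":
            return pl.Float32  # assumed dtype for numeric aggregations
        case "present":
            return pl.Boolean  # assumed dtype for presence indicators
        case "count" | "has_values_count":
            return pl.UInt32  # arm shown in the diff
        case _:
            raise ValueError(
                f"Column name {col} malformed! Expected aggregations: {', '.join(VALID_AGGS)}."
            )
```

Enumerating the expected aggregations in the error message, as this commit does, turns a bare "malformed" failure into one that tells the caller exactly which suffixes are accepted.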
