Skip to content

Commit

Permalink
generate-permutations cli for window_sizes and aggs
Browse files Browse the repository at this point in the history
  • Loading branch information
teyaberg committed Jun 10, 2024
1 parent d1d7bf0 commit df3e3aa
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 43 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ meds-tab-tabularize-static = "MEDS_tabular_automl.scripts.tabularize_static:main
meds-tab-tabularize-time-series = "MEDS_tabular_automl.scripts.tabularize_time_series:main"
meds-tab-cache-task = "MEDS_tabular_automl.scripts.cache_task:main"
meds-tab-xgboost = "MEDS_tabular_automl.scripts.launch_xgboost:main"
meds-tab-xgboost-sweep = "MEDS_tabular_automl.scripts.sweep_xgboost:main"
generate-permutations = "MEDS_tabular_automl.scripts.generate_permutations:main"


[project.optional-dependencies]
dev = ["pre-commit"]
Expand Down
47 changes: 47 additions & 0 deletions src/MEDS_tabular_automl/scripts/generate_permutations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import sys
from itertools import combinations


def format_print(permutations):
"""
Args:
permutations: List of all possible permutations of length > 1
Example:
>>> format_print([('2',), ('2', '3'), ('2', '3', '4'), ('2', '4'), ('3',), ('3', '4'), ('4',)])
[2],[2,3],[2,3,4],[2,4],[3],[3,4],[4]
"""
out_str = ""
for item in permutations:
out_str += f"[{','.join(item)}],"
out_str = out_str[:-1]
print(out_str)


def get_permutations(list_of_options):
"""Generate all possible permutations of a list of options passed as an arg.
Args:
- list_of_options (list): List of options.
Returns:
- list: List of all possible permutations of length > 1
Example:
>>> get_permutations(['2', '3', '4'])
[2],[2,3],[2,3,4],[2,4],[3],[3,4],[4]
"""
permutations = []
for i in range(1, len(list_of_options) + 1):
permutations.extend(list(combinations(list_of_options, r=i)))
format_print(sorted(permutations))


def main():
"""Generate all possible permutations of a list of options."""
list_of_options = list(sys.argv[1].strip("[]").split(","))
get_permutations(list_of_options)


if __name__ == "__main__":
main()
50 changes: 9 additions & 41 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def test_integration():
f"Time-Series Data matrix Should have {expected_num_rows}"
f" rows but has {ts_matrix.shape[0]}!"
)

# Step 4: Run the task_specific_caching script
cache_config = {
"MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
Expand All @@ -216,7 +215,7 @@ def test_integration():
"tqdm": False,
"loguru_init": True,
"tabularization.min_code_inclusion_frequency": 1,
"tabularization.aggs": "[static/present,static/first,code/count,value/sum]",
# "tabularization.aggs": "[static/present,static/first,code/count,value/sum]",
"tabularization.window_sizes": "[30d,365d,full]",
}
with initialize(
Expand All @@ -241,50 +240,19 @@ def test_integration():
out_f.parent.mkdir(parents=True, exist_ok=True)
df.write_parquet(out_f)

stderr, stdout = run_command(
"meds-tab-cache-task",
[],
cache_config,
"task_specific_caching",
stderr, stdout_ws = run_command(
"generate-permutations", ["[30d]"], {}, "generate-permutations window_sizes"
)
# Check the files are not empty

# Step 5: Run the xgboost script

xgboost_config_kwargs = {
"MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()),
"do_overwrite": False,
"seed": 1,
"hydra.verbose": True,
"tqdm": False,
"loguru_init": True,
"tabularization.min_code_inclusion_frequency": 1,
"tabularization.aggs": "[static/present,static/first,code/count,value/sum]",
"tabularization.window_sizes": "[30d,365d,full]",
}
with initialize(
version_base=None, config_path="../src/MEDS_tabular_automl/configs/"
): # path to config.yaml
overrides = [f"{k}={v}" for k, v in xgboost_config_kwargs.items()]
cfg = compose(config_name="launch_xgboost", overrides=overrides) # config.yaml
stderr, stdout = run_command(
"meds-tab-xgboost",
[],
xgboost_config_kwargs,
"xgboost",
stderr, stdout_agg = run_command(
"generate-permutations", ["[static/present,static/first]"], {}, "generate-permutations aggs"
)
output_files = list(Path(cfg.output_dir).parent.glob("**/*.json"))
assert len(output_files) == 1
# assert output_files[0].stem == '0.6667_model'

stderr, stdout = run_command(
"meds-tab-xgboost",
"meds-tab-cache-task",
[
"--multirun",
f"tabularization.aggs={stdout_agg.strip()}",
],
xgboost_config_kwargs,
"xgboost-sweep",
cache_config,
"task_specific_caching",
)
output_files = list(Path(cfg.output_dir).parent.glob("**/*.json"))
assert len(output_files) == 11
# assert output_files[0].stem == "model"
40 changes: 39 additions & 1 deletion tests/test_tabularize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

import json
import os
import subprocess
import tempfile
from io import StringIO
from pathlib import Path
Expand Down Expand Up @@ -391,3 +391,41 @@ def test_tabularize():
launch_xgboost.main(cfg)
output_files = list(Path(cfg.output_dir).glob("**/*.json"))
assert len(output_files) == 1


def run_command(script: str, args: list[str], hydra_kwargs: dict[str, str], test_name: str):
command_parts = [script] + args + [f"{k}={v}" for k, v in hydra_kwargs.items()]
command_out = subprocess.run(" ".join(command_parts), shell=True, capture_output=True)
stderr = command_out.stderr.decode()
stdout = command_out.stdout.decode()
if command_out.returncode != 0:
raise AssertionError(f"{test_name} failed!\nstdout:\n{stdout}\nstderr:\n{stderr}")
return stderr, stdout


def test_xgboost_config():
MEDS_cohort_dir = "blah"
stderr, stdout_ws = run_command(
"generate-permutations", ["[30d]"], {}, "generate-permutations window_sizes"
)
stderr, stdout_agg = run_command(
"generate-permutations", ["[static/present]"], {}, "generate-permutations aggs"
)
xgboost_config_kwargs = {
"MEDS_cohort_dir": MEDS_cohort_dir,
"do_overwrite": False,
"seed": 1,
"hydra.verbose": True,
"tqdm": False,
"loguru_init": True,
"tabularization.min_code_inclusion_frequency": 1,
"tabularization.window_sizes": f"{stdout_ws.strip()}",
}

with initialize(
version_base=None, config_path="../src/MEDS_tabular_automl/configs/"
): # path to config.yaml
overrides = [f"{k}={v}" for k, v in xgboost_config_kwargs.items()]
cfg = compose(config_name="launch_xgboost", overrides=overrides) # config.yaml
print(cfg.tabularization.window_sizes)
assert cfg.tabularization.window_sizes

0 comments on commit df3e3aa

Please sign in to comment.