Skip to content

Commit

Permalink
Fixed the doctest for deleting all models below the top k
Browse files Browse the repository at this point in the history
  • Loading branch information
Oufattole committed Sep 10, 2024
1 parent 8316365 commit f7e03dd
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/MEDS_tabular_automl/evaluation_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,33 +50,37 @@ def log_performance(self, best_model_performance):
logger.info("\n".join(log_performance_message))

def delete_below_top_k_models(self, performance, k, sweep_results_dir):
"""Save only top k models from the model directory and delete all other files.
"""Save only top k models from the sweep results directory and delete all other directories.
Args:
performance: DataFrame containing trial_name and performance metrics.
k: Number of top models to save.
model_dir: Directory containing models.
sweep_results_dir: Directory containing trial results.
Example:
>>> import tempfile
>>> import json
>>> import polars as pl
>>> from pathlib import Path
>>> performance = pl.DataFrame(
... {
... "trial_name": ["model1", "model2", "model3", "model4"],
... "trial_name": ["trial1", "trial2", "trial3", "trial4"],
... "tuning_auc": [0.9, 0.8, 0.7, 0.6],
... "test_auc": [0.9, 0.8, 0.7, 0.6],
... }
... )
>>> k = 2
>>> with tempfile.TemporaryDirectory() as model_dir:
... for model in performance["trial_name"]:
... with open(Path(model_dir) / f"{model}.json", 'w') as f:
... json.dump({"model_name": model, "content": "dummy data"}, f)
>>> with tempfile.TemporaryDirectory() as sweep_dir:
... for trial in performance["trial_name"]:
... trial_dir = Path(sweep_dir) / trial
... trial_dir.mkdir()
... with open(trial_dir / "model.json", 'w') as f:
... json.dump({"model_name": trial, "content": "dummy data"}, f)
... cb = EvaluationCallback()
... cb.delete_below_top_k_models(performance, k, model_dir)
... remaining_models = sorted(p.stem for p in Path(model_dir).iterdir())
>>> remaining_models
['model1', 'model2']
... cb.delete_below_top_k_models(performance, k, sweep_dir)
... remaining_trials = sorted(p.name for p in Path(sweep_dir).iterdir())
>>> remaining_trials
['trial1', 'trial2']
"""
logger.info(f"Deleting all models except top {k} models.")
top_k_models = performance.head(k)["trial_name"].cast(pl.String).to_list()
Expand Down

0 comments on commit f7e03dd

Please sign in to comment.