Commit: eval callback
teyaberg committed Sep 7, 2024
1 parent 0623aaa commit 0e985ee
Showing 4 changed files with 43 additions and 32 deletions.
src/MEDS_tabular_automl/evaluation_callback.py (22 additions, 29 deletions)
@@ -1,4 +1,3 @@
-import ast
 from pathlib import Path

 import polars as pl
@@ -16,51 +15,45 @@ def on_multirun_end(self, config: DictConfig, **kwargs):
         log_fp = Path(config.model_logging.model_log_dir)

         try:
-            performance = pl.read_csv(log_fp / "*/*.csv")
+            perf = pl.read_csv(log_fp / f"*/*{config.model_logging.performance_log_stem}.log")
         except Exception as e:
             raise FileNotFoundError(f"Log files incomplete or not found at {log_fp}, exception {e}.")

-        performance = performance.sort("tuning_auc", descending=True, nulls_last=True)
-        logger.info(performance.head(10))
+        perf = perf.sort("tuning_auc", descending=True, nulls_last=True)
+        logger.info(f"\nPerformance of the top 10 models:\n{perf.head(10)}")

-        # get best model_fp
-        best_model = performance[0, 0]
-
-        best_params_fp = log_fp / best_model / f"{config.model_logging.config_log_stem}.json"
-
-        # check if this file exists
-        if not best_params_fp.is_file():
-            raise FileNotFoundError(f"Best hyperparameters file not found at {best_params_fp}")
+        best_model = perf[0, 0]

         logger.info(f"The best model can be found at {best_model}")
-        # self.log_performance(performance.head(1))
-        # self.log_hyperparams(best_hyperparams)
+        self.log_performance(perf[0, :])
+        # self.log_hyperparams(log_fp / best_model / f"{config.model_logging.config_log_stem}.log")
         if hasattr(config, "model_saving.delete_below_top_k") and config.delete_below_top_k >= 0:
             self.delete_below_top_k_models(
-                performance, config.model_saving.delete_below_top_k, config.model_saving.model_dir
+                perf, config.model_saving.delete_below_top_k, config.model_saving.model_dir
             )

-        return performance.head(1)
+        return perf.head(1)

-    def log_performance(self, performance):
+    def log_performance(self, perf):
         """logger.info performance of the best model with nice formatting."""
-        logger.info("Performance of the best model:")
-        logger.info(f"Tuning AUC: {performance['tuning_auc'].values[0]}")
-        logger.info(f"Test AUC: {performance['test_auc'].values[0]}")
-
-    def log_hyperparams(self, hyperparams):
-        """logger.info hyperparameters of the best model with nice formatting."""
-        logger.info("Hyperparameters of the best model:")
         logger.info(
-            f"Tabularization: {OmegaConf.to_yaml(ast.literal_eval(hyperparams['tabularization'].values[0]))}"
-        )
-        logger.info(
-            f"Model parameters: {OmegaConf.to_yaml(ast.literal_eval(hyperparams['model_params'].values[0]))}"
+            "\nPerformance of the best model:\n",
+            f"Tuning AUC: {perf['tuning_auc'][0]}\nTest AUC: {perf['test_auc'][0]}",
         )

-    def delete_below_top_k_models(self, performance, k, model_dir):
+    def log_hyperparams(self, best_params_fp):
+        """logger.info hyperparameters of the best model with nice formatting."""
+        # check if this file exists
+        if not best_params_fp.is_file():
+            raise FileNotFoundError(f"Best hyperparameters file not found at {best_params_fp}")
+        best_params = OmegaConf.load(best_params_fp)
+        # print using OmegaConf.to_yaml
+        logger.info(f"\nHyperparameters of the best model:\n{OmegaConf.to_yaml(best_params)}")
+
+    def delete_below_top_k_models(self, perf, k, model_dir):
         """Save only top k models from the model directory and delete all other files."""
-        top_k_models = performance.head(k)["model_fp"].values
+        top_k_models = perf.head(k)["model_fp"].values
         for model_fp in Path(model_dir).iterdir():
             if model_fp.is_file() and model_fp.suffix != ".log" and str(model_fp) not in top_k_models:
                 model_fp.unlink()
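
For orientation (not part of the commit), a minimal sketch of the flow the reworked callback now relies on: globbing the per-run performance logs that log_to_logfile writes (see the utils.py change below) and ranking them by tuning AUC. The directory name sweep_logs and the literal "performance" stem are illustrative stand-ins for config.model_logging.model_log_dir and performance_log_stem.

from pathlib import Path

import polars as pl

# Hypothetical layout: each sweep run writes <model_log_dir>/<run>/<stem>.log
# as a CSV with the header "model_fp,tuning_auc,test_auc".
log_dir = Path("sweep_logs")  # stand-in for config.model_logging.model_log_dir

# As in the callback above, read_csv is handed a glob path so one call
# collects every run's log into a single DataFrame.
perf = pl.read_csv(log_dir / "*/*performance.log")

# Highest tuning AUC first; runs without an AUC sort to the end.
perf = perf.sort("tuning_auc", descending=True, nulls_last=True)

best_model_fp = perf[0, 0]  # first column of the top row is the model path
print(perf.head(10))
print(f"The best model can be found at {best_model_fp}")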
src/MEDS_tabular_automl/utils.py (2 additions, 2 deletions)
@@ -412,11 +412,11 @@ def log_to_logfile(model, cfg, output_fp):
     out_fp.mkdir(parents=True, exist_ok=True)

     # config as a json
-    config_fp = out_fp / f"{cfg.model_logging.config_log_stem}.json"
+    config_fp = out_fp / f"{cfg.model_logging.config_log_stem}.log"
     with open(config_fp, "w") as f:
         f.write(OmegaConf.to_yaml(cfg))

-    model_performance_fp = out_fp / f"{cfg.model_logging.performance_log_stem}.csv"
+    model_performance_fp = out_fp / f"{cfg.model_logging.performance_log_stem}.log"
     with open(model_performance_fp, "w") as f:
         f.write("model_fp,tuning_auc,test_auc\n")
         f.write(f"{output_fp},{model.evaluate()},{model.evaluate(split='held_out')}\n")
tests/test_integration.py (14 additions, 0 deletions)
@@ -229,3 +229,17 @@ def test_integration(tmp_path):
         cache_config,
         "task_specific_caching",
     )
+    stderr, stdout = run_command(
+        "meds-tab-model",
+        [
+            "--multirun",
+            f"tabularization.window_sizes={stdout_ws.strip()}",
+            f"tabularization.aggs={stdout_agg.strip()}",
+            "hydra.sweeper.n_jobs=5",
+            "hydra.sweeper.n_trials=10",
+        ],
+        cache_config,
+        "xgboost-model",
+    )
+    assert "The best model can be found at" in stderr
+    assert "Performance of the best model:" in stderr
tests/test_tabularize.py (5 additions, 1 deletion)
@@ -354,7 +354,11 @@ def test_tabularize(tmp_path):
     HydraConfig().set_config(cfg)
     launch_model.main(cfg)
     output_files = list(output_dir.glob("**/*.json"))
-    assert len(output_files) == 2
+    assert len(output_files) == 1
+
+    log_dir = Path(cfg.model_logging.model_log_dir)
+    log_csv = list(log_dir.glob("**/*.log"))
+    assert len(log_csv) == 2

     sklearnmodel_config_kwargs = {
         **shared_config,
