Removing unused function in evaluation callback.
mmcdermott committed Sep 8, 2024
1 parent 2563aaf commit 2d80905
Showing 1 changed file with 1 addition and 14 deletions.
15 changes: 1 addition & 14 deletions src/MEDS_tabular_automl/evaluation_callback.py
--- a/src/MEDS_tabular_automl/evaluation_callback.py
+++ b/src/MEDS_tabular_automl/evaluation_callback.py
@@ -3,13 +3,10 @@
 import polars as pl
 from hydra.experimental.callback import Callback
 from loguru import logger
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig
 
 
 class EvaluationCallback(Callback):
-    def __init__(self, **kwargs):
-        self.kwargs = kwargs
-
     def on_multirun_end(self, config: DictConfig, **kwargs):
         """Find best model based on log files and logger.info its performance and hyperparameters."""
         log_fp = Path(config.model_logging.model_log_dir)
@@ -27,7 +24,6 @@ def on_multirun_end(self, config: DictConfig, **kwargs):
 
         logger.info(f"The best model can be found at {best_model}")
         self.log_performance(perf[0, :])
-        # self.log_hyperparams(log_fp / best_model / f"{config.model_logging.config_log_stem}.log")
         if hasattr(config, "model_saving.delete_below_top_k") and config.delete_below_top_k >= 0:
             self.delete_below_top_k_models(
                 perf, config.model_saving.delete_below_top_k, config.model_saving.model_dir
@@ -43,15 +39,6 @@ def log_performance(self, perf):
             f"\nPerformance of best model:\nTuning AUC: {tuning_auc}\nTest AUC: {test_auc}",
         )
 
-    def log_hyperparams(self, best_params_fp):
-        """logger.info hyperparameters of the best model with nice formatting."""
-        # check if this file exists
-        if not best_params_fp.is_file():
-            raise FileNotFoundError(f"Best hyperparameters file not found at {best_params_fp}")
-        best_params = OmegaConf.load(best_params_fp)
-        # print using OmegaConf.to_yaml
-        logger.info(f"\nHyperparameters of the best model:\n{OmegaConf.to_yaml(best_params)}")
-
     def delete_below_top_k_models(self, perf, k, model_dir):
         """Save only top k models from the model directory and delete all other files."""
         top_k_models = perf.head(k)["model_fp"].values
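
For reference, the deleted helper can be kept as a standalone function if its behavior is still wanted outside the callback. The following is a minimal sketch, not part of this commit: it reuses only the calls that appear in the removed code (Path.is_file, OmegaConf.load, OmegaConf.to_yaml, loguru's logger.info), and the path expression in the trailing usage comment is copied from the commented-out call site that this commit also deletes, so the names it references (log_fp, best_model, config) are assumptions outside that method.

from pathlib import Path

from loguru import logger
from omegaconf import OmegaConf


def log_hyperparams(best_params_fp: Path) -> None:
    """Log the best model's hyperparameters as readable YAML."""
    # Fail loudly if the expected hyperparameter log file is missing.
    if not best_params_fp.is_file():
        raise FileNotFoundError(f"Best hyperparameters file not found at {best_params_fp}")
    # Load the saved config and render it as YAML before logging.
    best_params = OmegaConf.load(best_params_fp)
    logger.info(f"\nHyperparameters of the best model:\n{OmegaConf.to_yaml(best_params)}")


# Example call, mirroring the removed call site in on_multirun_end:
# log_hyperparams(log_fp / best_model / f"{config.model_logging.config_log_stem}.log")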
