diff --git a/src/MEDS_tabular_automl/evaluation_callback.py b/src/MEDS_tabular_automl/evaluation_callback.py
index d9236f8..30618db 100644
--- a/src/MEDS_tabular_automl/evaluation_callback.py
+++ b/src/MEDS_tabular_automl/evaluation_callback.py
@@ -15,30 +15,30 @@ def on_multirun_end(self, config: DictConfig, **kwargs):
         log_fp = Path(config.model_logging.model_log_dir)

         try:
-            perf = pl.read_csv(log_fp / f"*/*{config.model_logging.performance_log_stem}.log")
+            performance = pl.read_csv(log_fp / f"*/*{config.model_logging.performance_log_stem}.log")
         except Exception as e:
             raise FileNotFoundError(f"Log files incomplete or not found at {log_fp}, exception {e}.")

-        perf = perf.sort("tuning_auc", descending=True, nulls_last=True)
-        logger.info(f"\nPerformance of the top 10 models:\n{perf.head(10)}")
+        performance = performance.sort("tuning_auc", descending=True, nulls_last=True)
+        logger.info(f"\nPerformance of the top 10 models:\n{performance.head(10)}")

         # get best model_fp
-        best_model = perf[0, 0]
+        best_model = performance[0, 0]
         logger.info(f"The best model can be found at {best_model}")
-        self.log_performance(perf[0, :])
+        self.log_performance(performance[0, :])
         # self.log_hyperparams(log_fp / best_model / f"{config.model_logging.config_log_stem}.log")
         if hasattr(config, "model_saving.delete_below_top_k") and config.delete_below_top_k >= 0:
             self.delete_below_top_k_models(
-                perf, config.model_saving.delete_below_top_k, config.model_saving.model_dir
+                performance, config.model_saving.delete_below_top_k, config.model_saving.model_dir
             )
-        return perf.head(1)
+        return performance.head(1)

-    def log_performance(self, perf):
+    def log_performance(self, best_model_performance):
         """logger.info performance of the best model with nice formatting."""
-        tuning_auc = perf["tuning_auc"][0]
-        test_auc = perf["test_auc"][0]
+        tuning_auc = best_model_performance["tuning_auc"][0]
+        test_auc = best_model_performance["test_auc"][0]
         logger.info(
             f"\nPerformance of best model:\nTuning AUC: {tuning_auc}\nTest AUC: {test_auc}",
         )
@@ -52,9 +52,9 @@ def log_hyperparams(self, best_params_fp):
         # print using OmegaConf.to_yaml
         logger.info(f"\nHyperparameters of the best model:\n{OmegaConf.to_yaml(best_params)}")

-    def delete_below_top_k_models(self, perf, k, model_dir):
+    def delete_below_top_k_models(self, performance, k, model_dir):
         """Save only top k models from the model directory and delete all other files."""
-        top_k_models = perf.head(k)["model_fp"].values
+        top_k_models = performance.head(k)["model_fp"].values
         for model_fp in Path(model_dir).iterdir():
             if model_fp.is_file() and model_fp.suffix != ".log" and str(model_fp) not in top_k_models:
                 model_fp.unlink()
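
Two pre-existing issues sit in the unchanged context of this diff and survive the rename. First, `hasattr(config, "model_saving.delete_below_top_k")` tests for a single attribute literally named `"model_saving.delete_below_top_k"`, which a nested `DictConfig` won't have, so the pruning branch can never fire; even if it did, the condition reads `config.delete_below_top_k` rather than the nested key. Second, `performance.head(k)["model_fp"].values` is a pandas idiom; a polars `Series` exposes `.to_list()` instead. Below is a minimal sketch of one way to address both, assuming `OmegaConf.select` for dotted-path lookup; the `prune_below_top_k` helper is hypothetical and not part of this patch.

```python
from pathlib import Path

import polars as pl
from omegaconf import DictConfig, OmegaConf


def prune_below_top_k(performance: pl.DataFrame, config: DictConfig) -> None:
    """Delete every saved model file that is not among the top-k performers.

    Hypothetical helper illustrating two fixes; not part of the patch above.
    """
    # OmegaConf.select resolves the dotted path and returns the default when
    # the key is absent, unlike hasattr() with a dotted attribute name.
    k = OmegaConf.select(config, "model_saving.delete_below_top_k", default=-1)
    if k is None or k < 0:
        return
    # A polars Series has no pandas-style `.values`; `.to_list()` yields the
    # model file paths as plain Python strings for the membership test below.
    top_k_models = performance.head(k)["model_fp"].to_list()
    for model_fp in Path(config.model_saving.model_dir).iterdir():
        if model_fp.is_file() and model_fp.suffix != ".log" and str(model_fp) not in top_k_models:
            model_fp.unlink()
```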