diff --git a/aviary/predict.py b/aviary/predict.py index f50d1e2..cce5173 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -9,6 +9,7 @@ import torch from tqdm import tqdm +from aviary.core import Normalizer from aviary.utils import get_metrics, print_walltime if TYPE_CHECKING: @@ -90,7 +91,9 @@ def make_ensemble_predictions( model = model_cls(**model_params) model.to(device) - model.load_state_dict(checkpoint["model_state"]) + # some models save the state dict under a different key + state_dict_field = "model_state" if "model_state" in checkpoint else "state_dict" + model.load_state_dict(checkpoint[state_dict_field]) with torch.no_grad(): preds = np.concatenate( @@ -111,6 +114,19 @@ def make_ensemble_predictions( else: df[pred_col] = preds + # denormalize predictions if a normalizer was used during training + if "normalizer_dict" in checkpoint: + assert task_type == "regression", "Normalization only takes place for regression." + normalizer = Normalizer.from_state_dict( + checkpoint["normalizer_dict"][target_name] + ) + mean = normalizer.mean.cpu().numpy() + std = normalizer.std.cpu().numpy() + # denorm the mean and aleatoric uncertainties separately + df[pred_col] = df[pred_col] * std + mean + if model.robust: + df[ale_col] = df[ale_col] * std + df_preds = df.filter(regex=r"_pred_\d") if len(checkpoint_paths) > 1: