From 90e47c97f36ee02a143aeb5be55f0fb82d4458a2 Mon Sep 17 00:00:00 2001 From: Karan Bania <101618474+karannb@users.noreply.github.com> Date: Sun, 10 Nov 2024 20:52:45 +0530 Subject: [PATCH 1/8] Added normalization for predictions. --- aviary/predict.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/aviary/predict.py b/aviary/predict.py index f50d1e2..fbe0ad8 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -15,7 +15,7 @@ import wandb.apis.public from torch.utils.data import DataLoader - from aviary.core import BaseModelClass + from aviary.core import BaseModelClass, Normalizer from aviary.data import InMemoryDataLoader __author__ = "Janosh Riebesell" @@ -90,13 +90,20 @@ def make_ensemble_predictions( model = model_cls(**model_params) model.to(device) - model.load_state_dict(checkpoint["model_state"]) + # some models save the state dict under a different key + state_dict_field = "model_state" if "model_state" in checkpoint else "state_dict" + model.load_state_dict(checkpoint[state_dict_field]) with torch.no_grad(): preds = np.concatenate( [model(*inputs)[0].cpu().numpy() for inputs, *_ in data_loader] ).squeeze() + # denormalize predictions if a normalizer was used during training + if "normalizer_dict" in checkpoint: + normalizer = Normalizer.from_state_dict(checkpoint["normalizer_dict"][target_name]) + preds = normalizer.denorm(preds) + pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}" if model.robust: From b1bcbf01312e9536593c674636d7536ca364e122 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 10 Nov 2024 15:25:43 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- aviary/predict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aviary/predict.py b/aviary/predict.py index fbe0ad8..1b90457 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -101,7 +101,9 @@ def make_ensemble_predictions( # denormalize predictions if a normalizer was used during training if "normalizer_dict" in checkpoint: - normalizer = Normalizer.from_state_dict(checkpoint["normalizer_dict"][target_name]) + normalizer = Normalizer.from_state_dict( + checkpoint["normalizer_dict"][target_name] + ) preds = normalizer.denorm(preds) pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}" From ae60106bde8155f5799f4a776e2e268b943183d9 Mon Sep 17 00:00:00 2001 From: Karan Bania <101618474+karannb@users.noreply.github.com> Date: Sun, 10 Nov 2024 21:01:31 +0530 Subject: [PATCH 3/8] Fixed placement of Normalizer import to outside `TYPE_CHECKING` block. --- aviary/predict.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aviary/predict.py b/aviary/predict.py index 1b90457..acbb29b 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -9,13 +9,14 @@ import torch from tqdm import tqdm +from aviary.core import Normalizer from aviary.utils import get_metrics, print_walltime if TYPE_CHECKING: import wandb.apis.public from torch.utils.data import DataLoader - from aviary.core import BaseModelClass, Normalizer + from aviary.core import BaseModelClass from aviary.data import InMemoryDataLoader __author__ = "Janosh Riebesell" From b324ae9ae7ecdd9c436303884cecd5ca1e1c0e03 Mon Sep 17 00:00:00 2001 From: Karan Bania <101618474+karannb@users.noreply.github.com> Date: Mon, 11 Nov 2024 06:58:40 +0530 Subject: [PATCH 4/8] denorming aleatoric uncertainties separately. --- aviary/predict.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/aviary/predict.py b/aviary/predict.py index acbb29b..3e15696 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -102,10 +102,18 @@ def make_ensemble_predictions( # denormalize predictions if a normalizer was used during training if "normalizer_dict" in checkpoint: + assert task_type == "regression", "Normalization only takes place for regression." normalizer = Normalizer.from_state_dict( checkpoint["normalizer_dict"][target_name] ) - preds = normalizer.denorm(preds) + if model.robust: + # denorm the mean and aleatoroc uncertainties separately + mean, log_std = np.split(preds, 2, axis=1) + preds = normalizer.denorm(mean) + ale_std = np.exp(log_std) * normalizer.std + preds = np.column_stack([preds, ale_std]) + else: + preds = normalizer.denorm(preds) pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}" From b20a50e2b5f3a075b2ad8bfd787ee3c68e75425a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Nov 2024 01:28:59 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- aviary/predict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aviary/predict.py b/aviary/predict.py index 3e15696..1eb0497 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -102,7 +102,9 @@ def make_ensemble_predictions( # denormalize predictions if a normalizer was used during training if "normalizer_dict" in checkpoint: - assert task_type == "regression", "Normalization only takes place for regression." + assert ( + task_type == "regression" + ), "Normalization only takes place for regression." normalizer = Normalizer.from_state_dict( checkpoint["normalizer_dict"][target_name] ) From 97416e0a37472b5ca755b6d526b86c4c4674491b Mon Sep 17 00:00:00 2001 From: Karan Bania <101618474+karannb@users.noreply.github.com> Date: Mon, 18 Nov 2024 10:16:34 +0530 Subject: [PATCH 6/8] moved normaliztion outside the prediction loop. --- aviary/predict.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/aviary/predict.py b/aviary/predict.py index 1eb0497..3915e35 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -100,23 +100,6 @@ def make_ensemble_predictions( [model(*inputs)[0].cpu().numpy() for inputs, *_ in data_loader] ).squeeze() - # denormalize predictions if a normalizer was used during training - if "normalizer_dict" in checkpoint: - assert ( - task_type == "regression" - ), "Normalization only takes place for regression." - normalizer = Normalizer.from_state_dict( - checkpoint["normalizer_dict"][target_name] - ) - if model.robust: - # denorm the mean and aleatoroc uncertainties separately - mean, log_std = np.split(preds, 2, axis=1) - preds = normalizer.denorm(mean) - ale_std = np.exp(log_std) * normalizer.std - preds = np.column_stack([preds, ale_std]) - else: - preds = normalizer.denorm(preds) - pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}" if model.robust: @@ -131,6 +114,19 @@ def make_ensemble_predictions( else: df[pred_col] = preds + # denormalize predictions if a normalizer was used during training + if "normalizer_dict" in checkpoint: + assert ( + task_type == "regression" + ), "Normalization only takes place for regression." + normalizer = Normalizer.from_state_dict( + checkpoint["normalizer_dict"][target_name] + ) + # denorm the mean and aleatoric uncertainties separately + df[pred_col] = df[pred_col].apply(normalizer.denorm) + if model.robust: + df[ale_col] = df[ale_col] * normalizer.std + df_preds = df.filter(regex=r"_pred_\d") if len(checkpoint_paths) > 1: From b3ca19992046e0330263536f9d8690a780ecbcbb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 04:47:04 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- aviary/predict.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aviary/predict.py b/aviary/predict.py index 3915e35..1b1880a 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -116,9 +116,7 @@ def make_ensemble_predictions( # denormalize predictions if a normalizer was used during training if "normalizer_dict" in checkpoint: - assert ( - task_type == "regression" - ), "Normalization only takes place for regression." + assert task_type == "regression", "Normalization only takes place for regression." normalizer = Normalizer.from_state_dict( checkpoint["normalizer_dict"][target_name] ) From 3a72e04a90af65f6ed5c8577cb7cabfd3eb5a77a Mon Sep 17 00:00:00 2001 From: Karan Bania <101618474+karannb@users.noreply.github.com> Date: Mon, 18 Nov 2024 10:21:50 +0530 Subject: [PATCH 8/8] Fixed type errors. --- aviary/predict.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aviary/predict.py b/aviary/predict.py index 1b1880a..cce5173 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -120,10 +120,12 @@ def make_ensemble_predictions( normalizer = Normalizer.from_state_dict( checkpoint["normalizer_dict"][target_name] ) + mean = normalizer.mean.cpu().numpy() + std = normalizer.std.cpu().numpy() # denorm the mean and aleatoric uncertainties separately - df[pred_col] = df[pred_col].apply(normalizer.denorm) + df[pred_col] = df[pred_col] * std + mean if model.robust: - df[ale_col] = df[ale_col] * normalizer.std + df[ale_col] = df[ale_col] * std df_preds = df.filter(regex=r"_pred_\d")