Merge pull request #39 from ENSTA-U2IS/dev
📖 Improve the scaling tutorial & misc
o-laurent authored Aug 14, 2023
2 parents b651797 + 60f8757 commit adeaca6
Showing 6 changed files with 49 additions and 36 deletions.
auto_tutorials_source/tutorial_scaler.py (65 changes: 40 additions & 25 deletions)
@@ -5,40 +5,53 @@
Improve Top-label Calibration with Temperature Scaling
======================================================
-In this tutorial, we use torch-uncertainty to improve the calibration of the top-label predictions
-to improve the reliability of the underlying neural network.
+In this tutorial, we use *TorchUncertainty* to improve the calibration
+of the top-label predictions
+and the reliability of the underlying neural network.
-We also see how to use the datamodules outside any Lightning Trainer.
+We also see how to use the datamodules outside any Lightning trainers,
+and how to use TorchUncertainty's models.
-1. Loading the utilities
+1. Loading the Utilities
~~~~~~~~~~~~~~~~~~~~~~~~
In this tutorial, we will need:
-- torch to download the pretrained model
-- the Calibration Error metric to compute the ECE and evaluate the top-label calibration
+- torch for its objects
+- the "calibration error" metric to compute the ECE and evaluate the top-label calibration
- the CIFAR-100 datamodule to handle the data
-- the Temperature Scaler to improve the top-label calibration
+- a ResNet 18 as starting model
+- the temperature scaler to improve the top-label calibration
+- a utility to download hf models easily
"""

import torch
from torchmetrics import CalibrationError

from torch_uncertainty.datamodules import CIFAR100DataModule
+from torch_uncertainty.models.resnet import resnet18
from torch_uncertainty.post_processing import TemperatureScaler
+from torch_uncertainty.utils import load_hf

# %%
-# 2. Downloading a Pre-trained Model
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 2. Loading a model from TorchUncertainty's HF
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
-# To avoid training a model on CIFAR-100 from scratch, we will use here a model from https://github.com/chenyaofo/pytorch-cifar-models (thank you!)
+# To avoid training a model on CIFAR-100 from scratch, we load a model from Hugging Face.
# This can be done in a one liner:

-model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_resnet20", pretrained=True)
+# Build the model
+model = resnet18(in_channels=3, num_classes=100, groups=1, style="cifar")

+# Download the weights (the config is not used here)
+weights, config = load_hf("resnet18_c100")

+# Load the weights in the pre-built model
+model.load_state_dict(weights)

#%%
-# 3. Setting up the Datamodule and Dataloader
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 3. Setting up the Datamodule and Dataloaders
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# To get the dataloader from the datamodule, just call prepare_data, setup, and
# extract the first element of the test dataloader list. There are more than one
@@ -52,8 +65,8 @@
dataloader = dm.test_dataloader()[0]

#%%
-# 4. Iterate on the Dataloader and compute the ECE
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 4. Iterating on the Dataloader and Computing the ECE
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We first split the original test set into a calibration set and a test set for proper evaluation.
#
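The datamodule setup and the calibration/test split themselves fall outside the hunks shown here. As a rough sketch of what that step can look like, assuming the datamodule accepts root and batch_size arguments and using torch.utils.data.random_split for the 1000-sample calibration subset (illustrative code, not lines from this commit):

from torch.utils.data import DataLoader, random_split

# Assumed constructor arguments; only dm.test_dataloader()[0] appears in the diff
dm = CIFAR100DataModule(root="./data", batch_size=32)
dm.prepare_data()
dm.setup("test")
dataloader = dm.test_dataloader()[0]

# Hold out 1000 samples for calibration, keep the rest for the final evaluation
dataset = dataloader.dataset
cal_dataset, test_dataset = random_split(dataset, [1000, len(dataset) - 1000])
cal_dataloader = DataLoader(cal_dataset, batch_size=32)
test_dataloader = DataLoader(test_dataset, batch_size=32)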
@@ -73,33 +86,35 @@
# Iterate on the calibration dataloader
for sample, target in test_dataloader:
logits = model(sample)
-ece.update(logits, target)
+ece.update(logits.softmax(-1), target)

# Compute & print the calibration error
cal = ece.compute()

print(f"ECE before scaling - {cal*100:.3}%.")

#%%
-# 5. Fit the Scaler to Improve the Calibration
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 5. Fitting the Scaler to Improve the Calibration
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The TemperatureScaler has one parameter that can be used to temper the softmax.
# We minimize the tempered cross-entropy on a calibration set that we define here as
-# a subset of the test set and containing 1000 data.
+# a subset of the test set and containing 1000 data. Look at the code run by TemperatureScaler
+# `fit` method for more details.

# Fit the scaler on the calibration dataset
scaler = TemperatureScaler()
scaler = scaler.fit(model=model, calib_loader=cal_dataloader)

#%%
-# 6. Iterate again to compute the improved ECE
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 6. Iterating Again to Compute the Improved ECE
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We create a wrapper of the original model and the scaler using torch.nn.Sequential.
# This is possible because the scaler is derived from nn.Module.
#
-# Note that you will need to first reset the ECE metric to avoid mixing the scores of the previous and current iterations.
+# Note that you will need to first reset the ECE metric to avoid mixing the scores of
+# the previous and current iterations.

# Create the calibrated model
cal_model = torch.nn.Sequential(model, scaler)
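The loop that re-evaluates the wrapped model sits between this hunk and the next one. Under the same assumptions as the sketch above, that step could read as follows (again illustrative, not the commit's code):

# Reset the metric, then score the calibrated model on the held-out test split
ece.reset()
with torch.no_grad():
    for sample, target in test_dataloader:
        logits = cal_model(sample)
        ece.update(logits.softmax(-1), target)

print(f"ECE after scaling - {ece.compute() * 100:.3}%.")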
@@ -124,11 +139,11 @@
#
# Temperature scaling is very efficient when the calibration set is representative of the test set.
# In this case, we say that the calibration and test set are drawn from the same distribution.
-# However, this may not be True in real-world cases where dataset shift could happen.
+# However, this may not hold true in real-world cases where dataset shift could happen.

# %%
# References
# ----------
#
-# - **Expected Calibration Error:** Naeini, M. P., Cooper, G. F., & Hauskrecht, M. (2015). Obtaining Well Calibrated Probabilities Using Bayesian Binning. In `AAAI 2015 <https://arxiv.org/pdf/1411.0160.pdf>`_
-# - **Temperature Scaling:** Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. (2017). On calibration of modern neural networks. In `ICML 2017 <https://arxiv.org/pdf/1706.04599.pdf>`_
+# - **Expected Calibration Error:** Naeini, M. P., Cooper, G. F., & Hauskrecht, M. (2015). Obtaining Well Calibrated Probabilities Using Bayesian Binning. In `AAAI 2015 <https://arxiv.org/pdf/1411.0160.pdf>`_.
+# - **Temperature Scaling:** Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. (2017). On calibration of modern neural networks. In `ICML 2017 <https://arxiv.org/pdf/1706.04599.pdf>`_.
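For context on the step-5 comments above: temperature scaling, as described in the Guo et al. reference, only divides the logits by a single learned temperature before the softmax, so the top-1 prediction (and hence the accuracy) is unchanged while the confidence is rescaled. A self-contained toy sketch of the idea, not TorchUncertainty's implementation:

import torch
from torch import nn

class ToyTemperatureScaler(nn.Module):
    """Toy illustration: divide the logits by a learned temperature before the softmax."""

    def __init__(self, init_temperature: float = 1.5) -> None:
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1) * init_temperature)

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return logits / self.temperature

Fitting then amounts to minimizing the cross-entropy of softmax(logits / T) on the calibration set with respect to T only, with the network's weights frozen.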
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
[tool.poetry]
name = "torch_uncertainty"
version = "0.1.4"
description = "A PyTorch Library for benchmarking and leveraging efficient predictive uncertainty quantification techniques."
description = "TorchUncertainty: A maintained and collaborative PyTorch Library for benchmarking and leveraging predictive uncertainty quantification techniques."
authors = [
"ENSTA U2IS <[email protected]>",
"Adrien Lafage <[email protected]>",
setup.py (5 changes: 3 additions & 2 deletions)
@@ -5,8 +5,9 @@
setup(
name="torch_uncertainty",
version="0.1.4",
description="A PyTorch Library for benchmarking and leveraging efficient"
"predictive uncertainty quantification techniques.",
description="TorchUncertainty: A maintained and collaborative PyTorch"
"Library for benchmarking and leveraging predictive uncertainty"
"quantification techniques.",
author="Adrien Lafage & Olivier Laurent",
author_email="[email protected]",
url="https://torch-uncertainty.github.io/",
torch_uncertainty/datasets/cifar/cifar_c.py (3 changes: 1 addition & 2 deletions)
@@ -32,8 +32,7 @@ class CIFAR10C(VisionDataset):
References:
Benchmarking neural network robustness to common corruptions and
-perturbations. Dan Hendrycks and Thomas Dietterich.
-In ICLR, 2019.
+perturbations. Dan Hendrycks and Thomas Dietterich. In ICLR, 2019.
"""

base_folder = "CIFAR-10-C"
torch_uncertainty/datasets/mnist_c.py (7 changes: 4 additions & 3 deletions)
@@ -31,14 +31,15 @@ class MNISTC(VisionDataset):
References:
Mu, Norman, and Justin Gilmer. "MNIST-C: A robustness benchmark for
computer vision." In ICMLW 2019.
computer vision." In ICMLW 2019.
License:
-The dataset is released under the Creative Commons Attribution 4.0.
+The dataset is released by the dataset's authors under the Creative
+Commons Attribution 4.0.
Note:
This dataset does not contain severity levels. Raise an issue if you
-want someone to investigate this.
+want someone to investigate this.
"""

base_folder = "mnist_c"
torch_uncertainty/post_processing/temperature_scaler.py (3 changes: 0 additions & 3 deletions)
@@ -24,9 +24,6 @@ class TemperatureScaler(Scaler):
Reference:
Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. On calibration
of modern neural networks. In ICML 2017.
-Note:
-Inspired by `<https://github.com/gpleiss/temperature_scaling>`_
"""

def __init__(
