
Unified API #8

Merged · merged 8 commits on Jan 30, 2024
Changes from all commits
67 changes: 28 additions & 39 deletions README.md
@@ -1,54 +1,43 @@
# uqlib


General purpose Python library for **U**ncertainty **Q**uantification (methods and benchmarks) with [PyTorch](https://github.com/pytorch/pytorch) models.

`uqlib` is functional first and aims to be easy to use and extend. Iterative `uqlib` algorithms take the following unified form:
```python
state = transform.init(dict(model.named_parameters()))

for batch in dataloader:
    state = transform.update(state, batch)
```

Here `transform` is an algorithm kernel that is pre-built with all the necessary configuration arguments. For example:
```python
num_data = len(dataloader.dataset)
functional_model = uqlib.model_to_function(model)
log_posterior = lambda p, b: -loss_fn(functional_model(p, b), b) + prior(p) / num_data
optimizer = partial(torchopt.Adam, lr=1e-3)
transform = uqlib.vi.diag.build(log_posterior, optimizer, temperature=1/num_data)
```

Observe that `uqlib` recommends specifying `log_posterior` and `temperature` such that
`log_posterior` remains on the same scale for different batch sizes. `uqlib`
algorithms are designed to be stable as `temperature` goes to zero.
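
As a toy illustration of the scaling point above (a hypothetical snippet, not from the repo): a batch-averaged log-likelihood stays on roughly the same scale whatever the batch size, which is what keeps `log_posterior` comparable across dataloaders and temperatures.

```python
import torch

torch.manual_seed(0)
theta = torch.tensor(0.5)
data = torch.randn(1000)


def batch_log_lik(theta, batch):
    # mean (not sum) over the batch keeps the value O(1) in the batch size
    return torch.distributions.Normal(theta, 1.0, validate_args=False).log_prob(batch).mean()


# similar magnitudes despite a 10x difference in batch size
print(batch_log_lik(theta, data[:10]), batch_log_lik(theta, data[:100]))
```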


## Methods

- [ ] [Dropout](https://arxiv.org/abs/1506.02142)
- [ ] [Variational inference (mean-field and KFAC)](https://arxiv.org/abs/1601.00670)
    - Basic/naive NELBO added but this should be upgraded (to be optimised + KFAC)
      and tested.
- [ ] [Laplace approximation (mean-field and KFAC)](https://arxiv.org/abs/2106.14806)
    - Currently we have a basic Hessian diagonal implementation but this should be
      replaced with diagonal (and KFAC) Fisher information, which is guaranteed to be
      positive semi-definite (see the note after this list).
- [ ] [Deep Ensemble](https://arxiv.org/abs/1612.01474)
- [ ] [SGMCMC](https://arxiv.org/abs/1506.04696)
    - v0 implementation added but needs API finalising and tests on e.g. linear
      Gaussian models with known posterior mean + cov.
- [ ] Ensemble SGMCMC
- [ ] [SNGP](https://arxiv.org/abs/2006.10108)
- [ ] [Epistemic neural networks](https://arxiv.org/abs/2107.08924)
<!-- - [ ] [Conformal prediction](https://arxiv.org/abs/2107.07511) -->
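
A brief aside on the Fisher point above (a standard fact, not specific to this PR): the Fisher information is an expected outer product of score vectors, so it is positive semi-definite by construction and its diagonal entries are non-negative, whereas the Hessian of a log-posterior need not be.

$$
F(\theta) = \mathbb{E}_{y \sim p(\cdot \mid \theta)}\!\left[\nabla_\theta \log p(y \mid \theta)\, \nabla_\theta \log p(y \mid \theta)^\top\right],
\qquad
v^\top F(\theta)\, v = \mathbb{E}\!\left[\big(v^\top \nabla_\theta \log p(y \mid \theta)\big)^2\right] \ge 0 .
$$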

## Friends

Interfaces seamlessly with:

- [`torch`](https://github.com/pytorch/pytorch) and in particular [`torch.func`](https://pytorch.org/docs/stable/func.html).
- [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) for distributions and sampling (note that it is typically required to set `validate_args=False` to conform with the control flows in [`torch.func`](https://pytorch.org/docs/stable/func.html); see the sketch after this list).
- Functional and flexible torch optimizers from [`torchopt`](https://github.com/metaopt/torchopt) (the default for [`uqlib.vi`](uqlib/vi/), although `torch.optim` also interfaces easily).
- [`transformers`](https://github.com/huggingface/transformers) for pre-trained models.
- [`lightning`](https://github.com/Lightning-AI/lightning) for convenient training and logging, see [examples/lightning_autoencoder.py](examples/lightning_autoencoder.py).
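
A minimal sketch (hypothetical, not from the repo) of the `validate_args=False` pattern noted above when using `torch.func` transforms such as `grad` and `vmap`:

```python
import torch
from torch.func import grad, vmap


def log_prob(mu, x):
    # distribution constructed without argument validation, as recommended above
    return torch.distributions.Normal(mu, 1.0, validate_args=False).log_prob(x).sum()


mu = torch.zeros(3)
xs = torch.randn(5, 3)

# per-sample gradients with respect to mu, vectorised over the batch dimension of xs
per_sample_grads = vmap(grad(log_prob), in_dims=(None, 0))(mu, xs)
print(per_sample_grads.shape)  # torch.Size([5, 3])
```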

The functional transform interface is strongly inspired by frameworks such as
[`optax`](https://github.com/google-deepmind/optax) and [`BlackJAX`](https://github.com/blackjax-devs/blackjax).


## Contributing
80 changes: 80 additions & 0 deletions examples/lightning_autoencoder.py
@@ -0,0 +1,80 @@
import os
from torch import nn, utils
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
import torchopt

import uqlib

# Example from https://lightning.ai/docs/pytorch/stable/starter/introduction.html

method, config_args = uqlib.vi.diag, {"optimizer": torchopt.adam(lr=1e-3)}
# method, config_args = uqlib.sgmcmc.sghmc, {"lr": 1e-3}

encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

encoder_function = uqlib.model_to_function(encoder)
decoder_function = uqlib.model_to_function(decoder)


def log_posterior(params, batch):
    x, y = batch
    x = x.view(x.size(0), -1)
    z = encoder_function(params[0], x)
    x_hat = decoder_function(params[1], z)
    return torch.distributions.Normal(x_hat, 1, validate_args=False).log_prob(x).sum()


# define the LightningModule
class LitAutoEncoderUQ(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        self.state = self.transform.update(self.state, batch, inplace=True)
        # Logging to TensorBoard (if installed) by default
        for k, v in self.state._asdict().items():
            if isinstance(v, float):
                self.log(k, v)

    def configure_optimizers(self):
        self.transform = method.build(log_posterior, **config_args)
        all_params = [
            dict(self.encoder.named_parameters()),
            dict(self.decoder.named_parameters()),
        ]
        self.state = self.transform.init(all_params)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["state"] = self.state

    def on_load_checkpoint(self, checkpoint):
        self.state = checkpoint["state"]


autoencoderuq = LitAutoEncoderUQ(encoder, decoder)


# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoderuq, train_dataloaders=train_loader)


checkpoint = "./lightning_logs/version_3/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoderUQ.load_from_checkpoint(
    checkpoint, encoder=encoder, decoder=decoder
)


assert hasattr(autoencoder, "state")
45 changes: 13 additions & 32 deletions examples/yelp/yelp_subspace_laplace_diag_fisher.ipynb

Large diffs are not rendered by default.

114 changes: 42 additions & 72 deletions examples/yelp/yelp_subspace_sghmc.ipynb

Large diffs are not rendered by default.

28 changes: 13 additions & 15 deletions examples/yelp/yelp_subspace_vi_diag.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion experiments/load_sghmc.py
@@ -1,7 +1,8 @@
from sghmc.modules.classifier import Classifier
import torch

from utils.utils import load_optimizer_param_to_model

ckpt_path = None

model = Classifier()
22 changes: 22 additions & 0 deletions experiments/utils/utils.py
@@ -1,3 +1,8 @@
from typing import List
import torch
from torch import nn


def parse_devices(devices):
    devices = devices.split(",")
    devices_list = []
@@ -8,3 +13,20 @@ def parse_devices(devices):
            pass
        devices_list.append(device)
    return devices_list


def load_optimizer_param_to_model(model: nn.Module, groups: List[List[torch.Tensor]]):
    """Updates the model parameters in-place with the provided grouped parameters.

    Args:
        model: A torch.nn.Module object
        groups: A list of groups where each group is a list of parameters
    """
    optimizer_params = []
    for group in groups:
        for param in group:
            optimizer_params.append(torch.from_numpy(param))

    for model_param, optimizer_param in zip(list(model.parameters()), optimizer_params):
        model_param.data = optimizer_param
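
A hypothetical usage sketch (not part of the diff): since the implementation converts with `torch.from_numpy`, the groups here are numpy arrays (e.g. deserialized from a checkpoint) despite the `List[List[torch.Tensor]]` annotation.

```python
# Hypothetical example; the model and parameter groups below are made up for illustration.
import numpy as np
from torch import nn

from utils.utils import load_optimizer_param_to_model

model = nn.Linear(2, 3)
groups = [
    [np.zeros((3, 2), dtype=np.float32)],  # weight group
    [np.zeros(3, dtype=np.float32)],       # bias group
]
load_optimizer_param_to_model(model, groups)

# the model's weight and bias have been overwritten in-place
assert all((p == 0).all() for p in model.parameters())
```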
44 changes: 37 additions & 7 deletions tests/laplace/test_diag_fisher.py
@@ -57,9 +57,11 @@ def log_posterior(p, b):
        return log_posterior_n(p, b, model, len(xs)).mean()

    params = dict(model.named_parameters())

    transform = diag_fisher.build(log_posterior)
    laplace_state = transform.init(params)
    for batch in dataloader:
        laplace_state = transform.update(laplace_state, batch)

    expected = tree_map(lambda x: torch.zeros_like(x), params)
    for x, y in zip(xs, ys):
@@ -71,21 +73,49 @@ def log_posterior(p, b):
        assert torch.allclose(expected[key], laplace_state.prec_diag[key], atol=1e-5)

    # Also check full batch
    laplace_state_fb = transform.init(params)
    laplace_state_fb = transform.update(laplace_state_fb, (xs, ys))

    for key in expected:
        assert torch.allclose(expected[key], laplace_state_fb.prec_diag[key], atol=1e-5)

    # Test per_sample
    log_posterior_per_sample = partial(log_posterior_n, model=model, n_data=len(xs))
    transform_ps = diag_fisher.build(log_posterior_per_sample, per_sample=True)
    laplace_state_ps = transform_ps.init(params)
    for batch in dataloader:
        laplace_state_ps = transform_ps.update(
            laplace_state_ps,
            batch,
        )

    for key in expected:
        assert torch.allclose(
            laplace_state_ps.prec_diag[key], laplace_state_fb.prec_diag[key], atol=1e-5
        )

    # Test inplace
    laplace_state_ip = transform.init(params)
    laplace_state_ip2 = transform.update(
        laplace_state_ip,
        batch,
        inplace=True,
    )

    for key in expected:
        assert torch.allclose(
            laplace_state_ip2.prec_diag[key], laplace_state_ip.prec_diag[key], atol=1e-8
        )

    # Test not inplace
    laplace_state_ip_false = transform.update(
        laplace_state_ip,
        batch,
        inplace=False,
    )
    for key in expected:
        assert not torch.allclose(
            laplace_state_ip_false.prec_diag[key],
            laplace_state_ip.prec_diag[key],
            atol=1e-8,
        )

Review comment on the in-place checks above: "i don't get what this little bit is doing down here"

Reply (SamDuffield, Contributor Author): "I just wanted to test that inplace=True indeed modifies the input state in-place, whilst inplace=False leaves the input state untouched. Definitely this could be improved, although perhaps in a later PR"
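
To make the semantics under discussion concrete, here is a generic PyTorch illustration (not using `uqlib`, purely for exposition) of the difference between mutating the tensors held by a state in-place and producing new ones:

```python
import torch

state = {"prec_diag": torch.zeros(3)}
old_reference = state["prec_diag"]

# in-place style: the tensor inside the original state is mutated,
# so any existing reference observes the new values
state["prec_diag"].add_(1.0)
assert torch.allclose(old_reference, torch.ones(3))

# out-of-place style: a new tensor is created and the original is untouched
new_prec_diag = state["prec_diag"] + 1.0
assert torch.allclose(state["prec_diag"], torch.ones(3))
assert torch.allclose(new_prec_diag, torch.full((3,), 2.0))
```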
36 changes: 32 additions & 4 deletions tests/laplace/test_diag_hessian.py
@@ -52,9 +52,11 @@ def test_diag_hessian():
    log_posterior = partial(log_posterior_n, model=model, n_data=len(xs))

    params = dict(model.named_parameters())

    transform = diag_hessian.build(log_posterior)
    laplace_state = transform.init(params)
    for batch in dataloader:
        laplace_state = transform.update(laplace_state, batch)

    expected = tree_map(lambda x: torch.zeros_like(x), params)
    for x, y in zip(xs, ys):
@@ -66,8 +68,34 @@ def test_diag_hessian():
        assert torch.allclose(expected[key], laplace_state.prec_diag[key])

    # Also check full batch
    laplace_state_fb = transform.init(params)
    laplace_state_fb = transform.update(laplace_state_fb, (xs, ys))

    for key in expected:
        assert torch.allclose(expected[key], laplace_state_fb.prec_diag[key])

    # Test inplace
    laplace_state_ip = transform.init(params)
    laplace_state_ip2 = transform.update(
        laplace_state_ip,
        batch,
        inplace=True,
    )

    for key in expected:
        assert torch.allclose(
            laplace_state_ip2.prec_diag[key], laplace_state_ip.prec_diag[key], atol=1e-8
        )

    # Test not inplace
    laplace_state_ip_false = transform.update(
        laplace_state_ip,
        batch,
        inplace=False,
    )
    for key in expected:
        assert not torch.allclose(
            laplace_state_ip_false.prec_diag[key],
            laplace_state_ip.prec_diag[key],
            atol=1e-8,
        )
64 changes: 0 additions & 64 deletions tests/sgmcmc/optim/test_SGHMC.py

This file was deleted.
