Merge pull request #8 from normal-computing/unified-api
Unified API
SamDuffield authored Jan 30, 2024
2 parents ac07f37 + 431d054 commit a08c55d
Showing 23 changed files with 623 additions and 465 deletions.
67 changes: 28 additions & 39 deletions README.md
@@ -1,54 +1,43 @@
# uqlib


General purpose Python library for **U**ncertainty **Q**uantification with [`torch`](https://github.com/pytorch/pytorch).

General purpose Python library for **U**ncertainty **Q**uantification (methods and benchmarks) with [PyTorch](https://github.com/pytorch/pytorch) models.
`uqlib` is functional first and aims to be easy to use and extend. Iterative `uqlib` algorithms take the following unified form:
```python
state = transform.init(dict(model.named_parameters()))

- All methods should be linear in the number of parameters, and therefore able to handle large models (e.g. transformers).
- We should support uncertainty quantification over subsets of parameters.
- We should support arbitrary loss functions.
- We should support uncertainty over some subset of the parameters - *this will take some thinking about*.
- Bayesian methods should support arbitrary priors (we just need pointwise evaluations).
for batch in dataloader:
state = transform.update(state, batch)
```

Here `transform` is an algorithm kernel that is pre-built with all the necessary configuration arguments. For example:
```python
num_data = len(dataloader.dataset)
functional_model = uqlib.model_to_function(model)
log_posterior = lambda p, b: -loss_fn(functional_model(p, b), b) + prior(p) / num_data
optimizer = partial(torchopt.Adam, lr=1e-3)
transform = uqlib.vi.diag.build(log_posterior, optimizer, temperature=1/num_data)
```

## Friends

Should interface well with

- Existing optimisers in [torch.optim](https://pytorch.org/docs/stable/optim.html) (we do not need to provide gradient descent)
- [transformers](https://github.com/huggingface/transformers) for fine-tuning pre-trained models (we should make sure our uncertainty methods are also compatible in terms of inference/generation)
- [PyTorch Lightning](https://github.com/Lightning-AI/lightning) for convenient training and logging
Observe that `uqlib` recommends specifying `log_posterior` and `temperature` such that
`log_posterior` remains on the same scale for different batch sizes. `uqlib`
algorithms are designed to be stable as `temperature` goes to zero.
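
For intuition, here is a self-contained toy sketch of this convention (our own construction, not `uqlib` API): the likelihood term is a batch mean, so `log_posterior` stays on a per-sample scale whatever the batch size, and `temperature = 1 / num_data` (as in the example above) then corresponds, as we read the convention, to targeting the standard, untempered Bayesian posterior.
```python
import torch

num_data = 1000
params = torch.zeros(2)


def log_posterior(p, batch):
    # batch mean of the log-likelihood plus the prior scaled by 1 / num_data
    log_lik = torch.distributions.Normal(p.sum(), 1.0, validate_args=False).log_prob(batch).mean()
    log_prior = torch.distributions.Normal(0.0, 1.0, validate_args=False).log_prob(p).sum()
    return log_lik + log_prior / num_data


small_batch, large_batch = torch.randn(8), torch.randn(256)
# Both evaluations are on the same per-sample scale, whatever the batch size.
print(log_posterior(params, small_batch), log_posterior(params, large_batch))
```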


## Methods

- [ ] [Dropout](https://arxiv.org/abs/1506.02142)
- [ ] [Variational inference (mean-field and KFAC)](https://arxiv.org/abs/1601.00670)
- Basic/naive NELBO added but this should be upgraded (to be optimised + KFAC)
and tested.
- [ ] [Laplace approximation (mean-field and KFAC)](https://arxiv.org/abs/2106.14806)
  - Currently we have a basic Hessian diagonal implementation, but this should be
    replaced with diagonal (and KFAC) Fisher information, which is guaranteed to be positive definite.
- [ ] [Deep Ensemble](https://arxiv.org/abs/1612.01474)
- [ ] [SGMCMC](https://arxiv.org/abs/1506.04696)
  - v0 implementation added, but the API needs finalising along with tests on e.g. linear
    Gaussian models with known posterior mean and covariance.
- [ ] Ensemble SGMCMC
- [ ] [SNGP](https://arxiv.org/abs/2006.10108)
- [ ] [Epistemic neural networks](https://arxiv.org/abs/2107.08924)
<!-- - [ ] [Conformal prediction](https://arxiv.org/abs/2107.07511) -->

## Friends

## Benchmarks
Interfaces seamlessly with:

Benchmarks should extend beyond those in [uncertainty-baselines](https://github.com/google/uncertainty-baselines). We can include classification and regression as toy examples, but the leaderboard should consist of the following more practically relevant tasks:
- [`torch`](https://github.com/pytorch/pytorch) and in particular [`torch.func`](https://pytorch.org/docs/stable/func.html).
- [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) for distributions and sampling (note that it is typically required to set `validate_args=False` to conform with the control flows in [`torch.func`](https://pytorch.org/docs/stable/func.html)); see the short sketch after this list.
- Functional and flexible torch optimizers from [`torchopt`](https://github.com/metaopt/torchopt) (the default for [`uqlib.vi`](uqlib/vi/), though `torch.optim` also interfaces easily).
- [`transformers`](https://github.com/huggingface/transformers) for pre-trained models.
- [`lightning`](https://github.com/Lightning-AI/lightning) for convenient training and logging, see [examples/lightning_autoencoder.py](examples/lightning_autoencoder.py).
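
As a minimal sketch of the `validate_args=False` point above (plain `torch` code, not `uqlib` API):
```python
import torch
from torch.func import grad


def gaussian_log_prob(mu, x):
    # validate_args=False avoids data-dependent Python checks that can clash
    # with functional transforms such as torch.func.grad and torch.func.vmap
    return torch.distributions.Normal(mu, 1.0, validate_args=False).log_prob(x).sum()


mu, x = torch.zeros(3), torch.randn(3)
print(grad(gaussian_log_prob)(mu, x))  # gradient with respect to mu
```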

- [ ] Generation
- Aleatoric vs epistemic uncertainty (e.g. hallucination detection)
- [ ] Continual learning
  - Regression/classification/generation tasks but with a stream of data. Evaluate performance on current and historical data/tasks.
- [ ] Decision making
- Thompson sampling effectiveness
The functional transform interface is strongly inspired by frameworks such as
[`optax`](https://github.com/google-deepmind/optax) and [`BlackJAX`](https://github.com/blackjax-devs/blackjax).


## Contributing
80 changes: 80 additions & 0 deletions examples/lightning_autoencoder.py
@@ -0,0 +1,80 @@
import os
from torch import nn, utils
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
import torchopt

import uqlib

# Example from https://lightning.ai/docs/pytorch/stable/starter/introduction.html

method, config_args = uqlib.vi.diag, {"optimizer": torchopt.adam(lr=1e-3)}
# method, config_args = uqlib.sgmcmc.sghmc, {"lr": 1e-3}

encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

encoder_function = uqlib.model_to_function(encoder)
decoder_function = uqlib.model_to_function(decoder)


def log_posterior(params, batch):
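    # params is [encoder_params, decoder_params]; Gaussian reconstruction log-likelihood (no explicit prior term)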
x, y = batch
x = x.view(x.size(0), -1)
z = encoder_function(params[0], x)
x_hat = decoder_function(params[1], z)
return torch.distributions.Normal(x_hat, 1, validate_args=False).log_prob(x).sum()


# define the LightningModule
class LitAutoEncoderUQ(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder

def training_step(self, batch, batch_idx):
# training_step defines the train loop.
# it is independent of forward
self.state = self.transform.update(self.state, batch, inplace=True)
# Logging to TensorBoard (if installed) by default
for k, v in self.state._asdict().items():
if isinstance(v, float):
self.log(k, v)

def configure_optimizers(self):
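        # Build the uqlib transform and initialise its state over the encoder and decoder parameters.
        # No torch optimizer is returned; updates happen in training_step via transform.update.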
self.transform = method.build(log_posterior, **config_args)
all_params = [
dict(self.encoder.named_parameters()),
dict(self.decoder.named_parameters()),
]
self.state = self.transform.init(all_params)

def on_save_checkpoint(self, checkpoint):
checkpoint["state"] = self.state

def on_load_checkpoint(self, checkpoint):
self.state = checkpoint["state"]


autoencoderuq = LitAutoEncoderUQ(encoder, decoder)


# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoderuq, train_dataloaders=train_loader)


checkpoint = "./lightning_logs/version_3/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoderUQ.load_from_checkpoint(
checkpoint, encoder=encoder, decoder=decoder
)


assert hasattr(autoencoder, "state")
45 changes: 13 additions & 32 deletions examples/yelp/yelp_subspace_laplace_diag_fisher.ipynb

Large diffs are not rendered by default.

114 changes: 42 additions & 72 deletions examples/yelp/yelp_subspace_sghmc.ipynb

Large diffs are not rendered by default.

28 changes: 13 additions & 15 deletions examples/yelp/yelp_subspace_vi_diag.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion experiments/load_sghmc.py
@@ -1,7 +1,8 @@
from sghmc.modules.classifier import Classifier
from uqlib import load_optimizer_param_to_model
import torch

from utils.utils import load_optimizer_param_to_model

ckpt_path = None

model = Classifier()
22 changes: 22 additions & 0 deletions experiments/utils/utils.py
@@ -1,3 +1,8 @@
from typing import List
import torch
from torch import nn


def parse_devices(devices):
devices = devices.split(",")
devices_list = []
Expand All @@ -8,3 +13,20 @@ def parse_devices(devices):
pass
devices_list.append(device)
return devices_list


def load_optimizer_param_to_model(model: nn.Module, groups: List[List[torch.Tensor]]):
"""Updates the model parameters in-place with the provided grouped parameters.
Args:
model: A torch.nn.Module object
groups: A list of groups where each group is a list of parameters
"""

optimizer_params = []
for group in groups:
for param in group:
optimizer_params.append(torch.from_numpy(param))

for model_param, optimizer_param in zip(list(model.parameters()), optimizer_params):
model_param.data = optimizer_param
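

if __name__ == "__main__":
    # Hypothetical usage sketch (editorial illustration, not part of the repo):
    # the grouped entries are numpy arrays, since the function converts them with
    # torch.from_numpy despite the List[List[torch.Tensor]] annotation above.
    import numpy as np

    demo_model = nn.Linear(4, 2)
    demo_groups = [[np.zeros((2, 4), dtype=np.float32), np.zeros(2, dtype=np.float32)]]
    load_optimizer_param_to_model(demo_model, demo_groups)
    assert torch.all(demo_model.weight == 0) and torch.all(demo_model.bias == 0)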
44 changes: 37 additions & 7 deletions tests/laplace/test_diag_fisher.py
Expand Up @@ -57,9 +57,11 @@ def log_posterior(p, b):
return log_posterior_n(p, b, model, len(xs)).mean()

params = dict(model.named_parameters())
laplace_state = diag_fisher.init(params)

transform = diag_fisher.build(log_posterior)
laplace_state = transform.init(params)
for batch in dataloader:
laplace_state = diag_fisher.update(laplace_state, log_posterior, batch)
laplace_state = transform.update(laplace_state, batch)

expected = tree_map(lambda x: torch.zeros_like(x), params)
for x, y in zip(xs, ys):
Expand All @@ -71,21 +73,49 @@ def log_posterior(p, b):
assert torch.allclose(expected[key], laplace_state.prec_diag[key], atol=1e-5)

# Also check full batch
laplace_state_fb = diag_fisher.init(params)
laplace_state_fb = diag_fisher.update(laplace_state_fb, log_posterior, (xs, ys))
laplace_state_fb = transform.init(params)
laplace_state_fb = transform.update(laplace_state_fb, (xs, ys))

for key in expected:
assert torch.allclose(expected[key], laplace_state_fb.prec_diag[key], atol=1e-5)

# Test per_sample
log_posterior_per_sample = partial(log_posterior_n, model=model, n_data=len(xs))
laplace_state_ps = diag_fisher.init(params)
transform_ps = diag_fisher.build(log_posterior_per_sample, per_sample=True)
laplace_state_ps = transform_ps.init(params)
for batch in dataloader:
laplace_state_ps = diag_fisher.update(
laplace_state_ps, log_posterior_per_sample, batch, per_sample=True
laplace_state_ps = transform_ps.update(
laplace_state_ps,
batch,
)

for key in expected:
assert torch.allclose(
laplace_state_ps.prec_diag[key], laplace_state_fb.prec_diag[key], atol=1e-5
)

# Test inplace
laplace_state_ip = transform.init(params)
laplace_state_ip2 = transform.update(
laplace_state_ip,
batch,
inplace=True,
)

for key in expected:
assert torch.allclose(
laplace_state_ip2.prec_diag[key], laplace_state_ip.prec_diag[key], atol=1e-8
)

# Test not inplace
laplace_state_ip_false = transform.update(
laplace_state_ip,
batch,
inplace=False,
)
for key in expected:
assert not torch.allclose(
laplace_state_ip_false.prec_diag[key],
laplace_state_ip.prec_diag[key],
atol=1e-8,
)
36 changes: 32 additions & 4 deletions tests/laplace/test_diag_hessian.py
Expand Up @@ -52,9 +52,11 @@ def test_diag_hessian():
log_posterior = partial(log_posterior_n, model=model, n_data=len(xs))

params = dict(model.named_parameters())
laplace_state = diag_hessian.init(params)

transform = diag_hessian.build(log_posterior)
laplace_state = transform.init(params)
for batch in dataloader:
laplace_state = diag_hessian.update(laplace_state, log_posterior, batch)
laplace_state = transform.update(laplace_state, batch)

expected = tree_map(lambda x: torch.zeros_like(x), params)
for x, y in zip(xs, ys):
Expand All @@ -66,8 +68,34 @@ def test_diag_hessian():
assert torch.allclose(expected[key], laplace_state.prec_diag[key])

# Also check full batch
laplace_state_fb = diag_hessian.init(params)
laplace_state_fb = diag_hessian.update(laplace_state_fb, log_posterior, (xs, ys))
laplace_state_fb = transform.init(params)
laplace_state_fb = transform.update(laplace_state_fb, (xs, ys))

for key in expected:
assert torch.allclose(expected[key], laplace_state_fb.prec_diag[key])

# Test inplace
laplace_state_ip = transform.init(params)
laplace_state_ip2 = transform.update(
laplace_state_ip,
batch,
inplace=True,
)

for key in expected:
assert torch.allclose(
laplace_state_ip2.prec_diag[key], laplace_state_ip.prec_diag[key], atol=1e-8
)

# Test not inplace
laplace_state_ip_false = transform.update(
laplace_state_ip,
batch,
inplace=False,
)
for key in expected:
assert not torch.allclose(
laplace_state_ip_false.prec_diag[key],
laplace_state_ip.prec_diag[key],
atol=1e-8,
)
64 changes: 0 additions & 64 deletions tests/sgmcmc/optim/test_SGHMC.py

This file was deleted.

