Merge pull request #14 from normal-computing/add-opt
Add optimizers to unified API
SamDuffield authored Feb 28, 2024
2 parents ad12454 + dd077ae commit cb47f7d
Showing 5 changed files with 249 additions and 0 deletions.
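For orientation, here is a minimal sketch of the unified interface these files introduce, stitched together from the new tests and docstrings below; the quadratic loss, single-tensor parameters, and constant batch are illustrative placeholders rather than part of the library.

```python
import torch
import torchopt

import uqlib


def loss_fn(params, batch):
    # Toy quadratic loss; the second return value is the auxiliary output
    # expected by the `loss, aux = loss_fn(params, batch)` convention.
    return torch.sum(params**2), torch.tensor([])


# torch.optim route: pass the optimizer class plus its keyword arguments.
sgd_transform = uqlib.optim.build(loss_fn, torch.optim.SGD, lr=0.1)
sgd_state = sgd_transform.init(torch.tensor([1.0], requires_grad=True))

# TorchOpt route: pass an already-constructed functional optimizer.
adam_transform = uqlib.torchopt.build(loss_fn, torchopt.adam(lr=0.1))
adam_state = adam_transform.init(torch.tensor([1.0], requires_grad=True))

for batch in [torch.tensor([1.0])] * 100:
    sgd_state = sgd_transform.update(sgd_state, batch)
    adam_state = adam_transform.update(adam_state, batch)
```

Both routes expose the same `init`/`update` pair via `uqlib.types.Transform`, differing only in what their state carries (a live `torch.optim.Optimizer` versus a functional TorchOpt optimizer state).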
22 changes: 22 additions & 0 deletions tests/test_optim.py
@@ -0,0 +1,22 @@
import torch

import uqlib


def test_optim():
    optimizer_cls = torch.optim.SGD
    lr = 0.1

    def loss_fn(p, b):
        return torch.sum(p**2), torch.tensor([])

    transform = uqlib.optim.build(loss_fn, optimizer_cls, lr=lr)

    params = torch.tensor([1.0], requires_grad=True)
    state = transform.init(params)

    for _ in range(100):
        state = transform.update(state, torch.tensor([1.0]))

    assert state.loss < 1e-3
    assert state.params < 1e-3
22 changes: 22 additions & 0 deletions tests/test_torchopt.py
@@ -0,0 +1,22 @@
import torch
import torchopt

import uqlib


def test_torchopt():
    optimizer = torchopt.sgd(lr=0.1)

    def loss_fn(p, b):
        return torch.sum(p**2), torch.tensor([])

    transform = uqlib.torchopt.build(loss_fn, optimizer)

    params = torch.tensor([1.0], requires_grad=True)
    state = transform.init(params)

    for _ in range(100):
        state = transform.update(state, torch.tensor([1.0]))

    assert state.loss < 1e-3
    assert state.params < 1e-3
2 changes: 2 additions & 0 deletions uqlib/__init__.py
@@ -3,6 +3,8 @@
from uqlib import sgmcmc
from uqlib import types
from uqlib import vi
from uqlib import optim
from uqlib import torchopt

from uqlib.utils import model_to_function
from uqlib.utils import linearized_forward_diag
103 changes: 103 additions & 0 deletions uqlib/optim.py
@@ -0,0 +1,103 @@
from typing import Type, NamedTuple, Any
from functools import partial
import torch

from uqlib.types import TensorTree, Transform, LogProbFn


class OptimState(NamedTuple):
"""State of an optimizer.
Args:
params: Parameters to be optimised.
optimizer: torch.optim optimizer instance.
loss: Loss value.
aux: Auxiliary information from the loss function call.
"""

params: TensorTree
optimizer: torch.optim.Optimizer
loss: torch.tensor = torch.tensor(0.0)
aux: Any = None


def init(
    params: TensorTree,
    optimizer_cls: Type[torch.optim.Optimizer],
    *args: Any,
    **kwargs: Any,
) -> OptimState:
    """Initialise an optimizer.
    Args:
        params: Parameters to be optimised.
        optimizer_cls: Optimizer class from torch.optim.
        *args: Positional arguments to pass to the optimizer class.
        **kwargs: Keyword arguments to pass to the optimizer class.
    Returns:
        Initial OptimState.
    """
    opt_params = [params] if isinstance(params, torch.Tensor) else params

    optimizer = optimizer_cls(opt_params, *args, **kwargs)
    return OptimState(params, optimizer)


def update(
    state: OptimState,
    batch: TensorTree,
    loss_fn: LogProbFn,
    inplace: bool = True,
) -> OptimState:
    """Perform a single update step of the optimizer.
    Args:
        state: Current optimizer state.
        batch: Input data to loss_fn.
        loss_fn: Function that takes the parameters and batch and returns the loss
            and auxiliary information, of the form `loss, aux = loss_fn(params, batch)`.
        inplace: Whether to update the parameters in place.
            inplace=False is not supported for uqlib.optim.
    Returns:
        Updated OptimState.
    """
    if not inplace:
        raise NotImplementedError("inplace=False not supported for uqlib.optim")
    state.optimizer.zero_grad()
    loss, aux = loss_fn(state.params, batch)
    loss.backward()
    state.optimizer.step()
    return OptimState(state.params, state.optimizer, loss.detach(), aux)


def build(
    loss_fn: LogProbFn,
    optimizer: Type[torch.optim.Optimizer],
    **kwargs: Any,
) -> Transform:
    """Build an optimizer transform from torch.optim.
    Example usage:
    ```
    transform = build(loss_fn, torch.optim.Adam, lr=0.1)
    state = transform.init(params)
    for batch in dataloader:
        state = transform.update(state, batch)
    ```
    Args:
        loss_fn: Function that takes the parameters and batch and returns the loss
            and auxiliary information, of the form `loss, aux = loss_fn(params, batch)`.
        optimizer: Optimizer class from torch.optim.
        **kwargs: Keyword arguments to pass to the optimizer class.
    Returns:
        Optimizer transform (uqlib.types.Transform instance).
    """
    init_fn = partial(init, optimizer_cls=optimizer, **kwargs)
    update_fn = partial(update, loss_fn=loss_fn)
    return Transform(init_fn, update_fn)
100 changes: 100 additions & 0 deletions uqlib/torchopt.py
@@ -0,0 +1,100 @@
from typing import NamedTuple, Any
from functools import partial
import torch
import torchopt

from uqlib.types import TensorTree, Transform, LogProbFn


class TorchOptState(NamedTuple):
"""State of a TorchOpt optimizer.
Args:
params: Parameters to be optimised.
opt_state: TorchOpt optimizer state.
loss: Loss value.
aux: Auxiliary information from the loss function call.
"""

params: TensorTree
opt_state: torch.optim.Optimizer
loss: torch.tensor = torch.tensor(0.0)
aux: Any = None


def init(
    params: TensorTree,
    optimizer: torchopt.base.GradientTransformation,
) -> TorchOptState:
    """Initialise a TorchOpt optimizer.
    Args:
        params: Parameters to be optimised.
        optimizer: TorchOpt functional optimizer.
            Make sure to use the functional (lower-case) API, e.g. torchopt.adam().
    Returns:
        Initial TorchOptState.
    """
    opt_state = optimizer.init(params)
    return TorchOptState(params, opt_state)


def update(
    state: TorchOptState,
    batch: TensorTree,
    loss_fn: LogProbFn,
    optimizer: torchopt.base.GradientTransformation,
    inplace: bool = True,
) -> TorchOptState:
    """Update the optimizer state.
    Args:
        state: Current state.
        batch: Batch of data.
        loss_fn: Loss function of the form `loss, aux = loss_fn(params, batch)`.
        optimizer: TorchOpt functional optimizer.
            Make sure to use the functional (lower-case) API, e.g. torchopt.adam().
        inplace: Whether to update the parameters in place.
    Returns:
        Updated state.
    """
    params = state.params
    opt_state = state.opt_state
    with torch.no_grad():
        grads, (loss, aux) = torch.func.grad_and_value(loss_fn, has_aux=True)(
            params, batch
        )
    updates, opt_state = optimizer.update(grads, opt_state)
    params = torchopt.apply_updates(params, updates, inplace=inplace)
    return TorchOptState(params, opt_state, loss, aux)


def build(
    loss_fn: LogProbFn,
    optimizer: torchopt.base.GradientTransformation,
) -> Transform:
    """Build a TorchOpt optimizer transform.
    Example usage:
    ```
    transform = build(loss_fn, torchopt.adam(lr=0.1))
    state = transform.init(params)
    for batch in dataloader:
        state = transform.update(state, batch)
    ```
    Args:
        loss_fn: Loss function of the form `loss, aux = loss_fn(params, batch)`.
        optimizer: TorchOpt functional optimizer.
            Make sure to use the functional (lower-case) API, e.g. torchopt.adam().
    Returns:
        TorchOpt optimizer transform (uqlib.types.Transform instance).
    """
    init_fn = partial(init, optimizer=optimizer)
    update_fn = partial(update, optimizer=optimizer, loss_fn=loss_fn)
    return Transform(init_fn, update_fn)
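
One functional nuance worth noting: because `torchopt.apply_updates` accepts an `inplace` flag, the TorchOpt route can also run out of place, whereas `uqlib.optim.update` raises `NotImplementedError` for `inplace=False`. Below is a small sketch of this, assuming the `inplace` keyword is simply forwarded through the built transform's `update` (which the `partial` above leaves free); the loss and data are again illustrative placeholders.

```python
import torch
import torchopt

import uqlib


def loss_fn(p, b):
    return torch.sum(p**2), torch.tensor([])


transform = uqlib.torchopt.build(loss_fn, torchopt.sgd(lr=0.1))

params = torch.tensor([1.0])
state = transform.init(params)

# Out-of-place step: the original tensor is left untouched and the
# returned state carries freshly allocated parameters.
new_state = transform.update(state, torch.tensor([1.0]), inplace=False)
assert torch.equal(params, torch.tensor([1.0]))
assert new_state.params < params
```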
