diff --git a/tests/test_optim.py b/tests/test_optim.py
new file mode 100644
index 00000000..8c3658db
--- /dev/null
+++ b/tests/test_optim.py
@@ -0,0 +1,22 @@
+import torch
+
+import uqlib
+
+
+def test_optim():
+    optimizer_cls = torch.optim.SGD
+    lr = 0.1
+
+    def loss_fn(p, b):
+        return torch.sum(p**2), torch.tensor([])
+
+    transform = uqlib.optim.build(loss_fn, optimizer_cls, lr=lr)
+
+    params = torch.tensor([1.0], requires_grad=True)
+    state = transform.init(params)
+
+    for _ in range(100):
+        state = transform.update(state, torch.tensor([1.0]))
+
+    assert state.loss < 1e-3
+    assert state.params < 1e-3
diff --git a/tests/test_torchopt.py b/tests/test_torchopt.py
new file mode 100644
index 00000000..7a1b9956
--- /dev/null
+++ b/tests/test_torchopt.py
@@ -0,0 +1,22 @@
+import torch
+import torchopt
+
+import uqlib
+
+
+def test_torchopt():
+    optimizer = torchopt.sgd(lr=0.1)
+
+    def loss_fn(p, b):
+        return torch.sum(p**2), torch.tensor([])
+
+    transform = uqlib.torchopt.build(loss_fn, optimizer)
+
+    params = torch.tensor([1.0], requires_grad=True)
+    state = transform.init(params)
+
+    for _ in range(100):
+        state = transform.update(state, torch.tensor([1.0]))
+
+    assert state.loss < 1e-3
+    assert state.params < 1e-3
diff --git a/uqlib/__init__.py b/uqlib/__init__.py
index 0edd023d..3d7f35cd 100644
--- a/uqlib/__init__.py
+++ b/uqlib/__init__.py
@@ -3,6 +3,8 @@
 from uqlib import sgmcmc
 from uqlib import types
 from uqlib import vi
+from uqlib import optim
+from uqlib import torchopt
 
 from uqlib.utils import model_to_function
 from uqlib.utils import linearized_forward_diag
diff --git a/uqlib/optim.py b/uqlib/optim.py
new file mode 100644
index 00000000..774b2ab9
--- /dev/null
+++ b/uqlib/optim.py
@@ -0,0 +1,103 @@
+from typing import Type, NamedTuple, Any
+from functools import partial
+import torch
+
+from uqlib.types import TensorTree, Transform, LogProbFn
+
+
+class OptimState(NamedTuple):
+    """State of an optimizer.
+
+    Args:
+        params: Parameters to be optimised.
+        optimizer: torch.optim optimizer instance.
+        loss: Loss value.
+        aux: Auxiliary information from the loss function call.
+    """
+
+    params: TensorTree
+    optimizer: torch.optim.Optimizer
+    loss: torch.Tensor = torch.tensor(0.0)
+    aux: Any = None
+
+
+def init(
+    params: TensorTree,
+    optimizer_cls: Type[torch.optim.Optimizer],
+    *args: Any,
+    **kwargs: Any,
+) -> OptimState:
+    """Initialise an optimizer.
+
+    Args:
+        params: Parameters to be optimised.
+        optimizer_cls: Optimizer class from torch.optim.
+        *args: Positional arguments to pass to the optimizer class.
+        **kwargs: Keyword arguments to pass to the optimizer class.
+
+    Returns:
+        Initial OptimState.
+    """
+    opt_params = [params] if isinstance(params, torch.Tensor) else params
+
+    optimizer = optimizer_cls(opt_params, *args, **kwargs)
+    return OptimState(params, optimizer)
+
+
+def update(
+    state: OptimState,
+    batch: TensorTree,
+    loss_fn: LogProbFn,
+    inplace: bool = True,
+) -> OptimState:
+    """Perform a single update step of the optimizer.
+
+    Args:
+        state: Current optimizer state.
+        batch: Input data to loss_fn.
+        loss_fn: Function that takes the parameters and batch and returns the loss,
+            of the form `loss, aux = fn(params, batch)`.
+        inplace: Whether to update the parameters in place.
+            inplace=False is not supported for uqlib.optim.
+
+    Returns:
+        Updated OptimState.
+    """
+    if not inplace:
+        raise NotImplementedError("inplace=False not supported for uqlib.optim")
+    state.optimizer.zero_grad()
+    loss, aux = loss_fn(state.params, batch)
+    loss.backward()
+    state.optimizer.step()
+    return OptimState(state.params, state.optimizer, loss.detach(), aux)
+
+
+def build(
+    loss_fn: LogProbFn,
+    optimizer: Type[torch.optim.Optimizer],
+    **kwargs: Any,
+) -> Transform:
+    """Build an optimizer transform from torch.optim.
+
+    Example usage:
+
+    ```
+    transform = build(loss_fn, torch.optim.Adam, lr=0.1)
+    state = transform.init(params)
+
+    for batch in dataloader:
+        state = transform.update(state, batch)
+    ```
+
+    Args:
+        loss_fn: Function that takes the parameters and batch and returns the loss,
+            of the form `loss, aux = fn(params, batch)`.
+        optimizer: Optimizer class from torch.optim.
+        **kwargs: Keyword arguments to pass to the optimizer class.
+
+    Returns:
+        Optimizer transform (uqlib.types.Transform instance).
+    """
+    init_fn = partial(init, optimizer_cls=optimizer, **kwargs)
+    update_fn = partial(update, loss_fn=loss_fn)
+    return Transform(init_fn, update_fn)
diff --git a/uqlib/torchopt.py b/uqlib/torchopt.py
new file mode 100644
index 00000000..57216c42
--- /dev/null
+++ b/uqlib/torchopt.py
@@ -0,0 +1,100 @@
+from typing import NamedTuple, Any
+from functools import partial
+import torch
+import torchopt
+
+from uqlib.types import TensorTree, Transform, LogProbFn
+
+
+class TorchOptState(NamedTuple):
+    """State of a TorchOpt optimizer.
+
+    Args:
+        params: Parameters to be optimised.
+        opt_state: TorchOpt optimizer state.
+        loss: Loss value.
+        aux: Auxiliary information from the loss function call.
+    """
+
+    params: TensorTree
+    opt_state: TensorTree
+    loss: torch.Tensor = torch.tensor(0.0)
+    aux: Any = None
+
+
+def init(
+    params: TensorTree,
+    optimizer: torchopt.base.GradientTransformation,
+) -> TorchOptState:
+    """Initialise a TorchOpt optimizer.
+
+    Args:
+        params: Parameters to be optimised.
+        optimizer: TorchOpt functional optimizer.
+            Make sure to use the lower-case functional API, e.g. torchopt.adam().
+
+    Returns:
+        Initial TorchOptState.
+    """
+    opt_state = optimizer.init(params)
+    return TorchOptState(params, opt_state)
+
+
+def update(
+    state: TorchOptState,
+    batch: TensorTree,
+    loss_fn: LogProbFn,
+    optimizer: torchopt.base.GradientTransformation,
+    inplace: bool = True,
+) -> TorchOptState:
+    """Update the optimizer state.
+
+    Args:
+        state: Current state.
+        batch: Batch of data.
+        loss_fn: Loss function, of the form `loss, aux = fn(params, batch)`.
+        optimizer: TorchOpt functional optimizer.
+            Make sure to use the lower-case functional API, e.g. torchopt.adam().
+        inplace: Whether to update the state in place.
+
+    Returns:
+        Updated state.
+    """
+    params = state.params
+    opt_state = state.opt_state
+    with torch.no_grad():
+        grads, (loss, aux) = torch.func.grad_and_value(loss_fn, has_aux=True)(
+            params, batch
+        )
+    updates, opt_state = optimizer.update(grads, opt_state)
+    params = torchopt.apply_updates(params, updates, inplace=inplace)
+    return TorchOptState(params, opt_state, loss, aux)
+
+
+def build(
+    loss_fn: LogProbFn,
+    optimizer: torchopt.base.GradientTransformation,
+) -> Transform:
+    """Build a TorchOpt optimizer transform.
+
+    Example usage:
+
+    ```
+    transform = build(loss_fn, torchopt.adam(lr=0.1))
+    state = transform.init(params)
+
+    for batch in dataloader:
+        state = transform.update(state, batch)
+    ```
+
+    Args:
+        loss_fn: Loss function, of the form `loss, aux = fn(params, batch)`.
+        optimizer: TorchOpt functional optimizer.
+            Make sure to use the lower-case functional API, e.g. torchopt.adam().
+
+    Returns:
+        TorchOpt optimizer transform (uqlib.types.Transform instance).
+    """
+    init_fn = partial(init, optimizer=optimizer)
+    update_fn = partial(update, optimizer=optimizer, loss_fn=loss_fn)
+    return Transform(init_fn, update_fn)
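A minimal end-to-end sketch of how the two new transforms are used side by side. The quadratic `loss_fn`, the `batches` list standing in for a dataloader, and the parameter values below are illustrative assumptions, not code from this diff:

```
import torch
import torchopt

import uqlib

def loss_fn(params, batch):
    # Toy loss of the form `loss, aux = fn(params, batch)` expected by both modules.
    return torch.sum((params - batch) ** 2), torch.tensor([])

batches = [torch.tensor([0.0]) for _ in range(50)]  # stand-in for a dataloader

# uqlib.optim wraps a torch.optim class; parameters are updated in place.
sgd_transform = uqlib.optim.build(loss_fn, torch.optim.SGD, lr=0.1)
sgd_state = sgd_transform.init(torch.tensor([1.0], requires_grad=True))
for batch in batches:
    sgd_state = sgd_transform.update(sgd_state, batch)

# uqlib.torchopt wraps a TorchOpt functional optimizer (lower case, e.g. torchopt.adam).
adam_transform = uqlib.torchopt.build(loss_fn, torchopt.adam(lr=0.1))
adam_state = adam_transform.init(torch.tensor([1.0]))
for batch in batches:
    adam_state = adam_transform.update(adam_state, batch)

print(sgd_state.loss, adam_state.loss)  # both losses shrink towards zero
```

Both loops return a fresh state NamedTuple on every call; uqlib.optim mutates the parameter tensor in place (and only supports inplace=True), whereas uqlib.torchopt can also produce purely functional updates by passing inplace=False.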