Commit
Merge pull request #14 from normal-computing/add-opt
Add optimizers to unified API
Showing 5 changed files with 249 additions and 0 deletions.
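The PR gives torch.optim and TorchOpt optimizers a single build / init / update interface. Below is a minimal sketch of that pattern, pieced together from the diffs that follow (the quadratic `loss_fn` and the placeholder batches are illustrative only):

```python
import torch
import uqlib

# Unified loss signature: loss, aux = loss_fn(params, batch).
def loss_fn(p, batch):
    return torch.sum(p**2), torch.tensor([])

params = torch.tensor([1.0], requires_grad=True)

# Same build/init/update pattern for either backend:
transform = uqlib.optim.build(loss_fn, torch.optim.SGD, lr=0.1)  # torch.optim
# transform = uqlib.torchopt.build(loss_fn, torchopt.sgd(lr=0.1))  # TorchOpt

state = transform.init(params)
for batch in [torch.tensor([])] * 5:  # placeholder batches
    state = transform.update(state, batch)
print(state.loss)  # latest loss value, carried on the state
```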
@@ -0,0 +1,22 @@
import torch

import uqlib


def test_optim():
    optimizer_cls = torch.optim.SGD
    lr = 0.1

    def loss_fn(p, b):
        return torch.sum(p**2), torch.tensor([])

    transform = uqlib.optim.build(loss_fn, optimizer_cls, lr=lr)

    params = torch.tensor([1.0], requires_grad=True)
    state = transform.init(params)

    for _ in range(100):
        state = transform.update(state, torch.tensor([1.0]))

    assert state.loss < 1e-3
    assert state.params < 1e-3
@@ -0,0 +1,22 @@
import torch
import torchopt

import uqlib


def test_torchopt():
    optimizer = torchopt.sgd(lr=0.1)

    def loss_fn(p, b):
        return torch.sum(p**2), torch.tensor([])

    transform = uqlib.torchopt.build(loss_fn, optimizer)

    params = torch.tensor([1.0], requires_grad=True)
    state = transform.init(params)

    for _ in range(100):
        state = transform.update(state, torch.tensor([1.0]))

    assert state.loss < 1e-3
    assert state.params < 1e-3
@@ -0,0 +1,103 @@
from typing import Type, NamedTuple, Any
from functools import partial
import torch

from uqlib.types import TensorTree, Transform, LogProbFn


class OptimState(NamedTuple):
    """State of an optimizer.
    Args:
        params: Parameters to be optimised.
        optimizer: torch.optim optimizer instance.
        loss: Loss value.
        aux: Auxiliary information from the loss function call.
    """

    params: TensorTree
    optimizer: torch.optim.Optimizer
    loss: torch.Tensor = torch.tensor(0.0)
    aux: Any = None


def init(
    params: TensorTree,
    optimizer_cls: Type[torch.optim.Optimizer],
    *args: Any,
    **kwargs: Any,
) -> OptimState:
    """Initialise an optimizer.
    Args:
        params: Parameters to be optimised.
        optimizer_cls: Optimizer class from torch.optim.
        *args: Positional arguments to pass to the optimizer class.
        **kwargs: Keyword arguments to pass to the optimizer class.
    Returns:
        Initial OptimState.
    """
    # torch.optim expects an iterable of tensors (or parameter groups).
    opt_params = [params] if isinstance(params, torch.Tensor) else params

    optimizer = optimizer_cls(opt_params, *args, **kwargs)
    return OptimState(params, optimizer)


def update(
    state: OptimState,
    batch: TensorTree,
    loss_fn: LogProbFn,
    inplace: bool = True,
) -> OptimState:
    """Perform a single update step of the optimizer.
    Args:
        state: Current optimizer state.
        batch: Input data to loss_fn.
        loss_fn: Function that takes the parameters and returns the loss,
            of the form `loss, aux = fn(params, batch)`.
        inplace: Whether to update the parameters in place.
            inplace=False is not supported for uqlib.optim.
    Returns:
        Updated OptimState.
    """
    if not inplace:
        raise NotImplementedError("inplace=False not supported for uqlib.optim")
    state.optimizer.zero_grad()
    loss, aux = loss_fn(state.params, batch)
    loss.backward()
    state.optimizer.step()
    return OptimState(state.params, state.optimizer, loss.detach(), aux)


def build(
    loss_fn: LogProbFn,
    optimizer: Type[torch.optim.Optimizer],
    **kwargs: Any,
) -> Transform:
    """Builds an optimizer transform from torch.optim.
    Example usage:
    ```
    transform = build(loss_fn, torch.optim.Adam, lr=0.1)
    state = transform.init(params)
    for batch in dataloader:
        state = transform.update(state, batch)
    ```
    Args:
        loss_fn: Function that takes the parameters and returns the loss,
            of the form `loss, aux = fn(params, batch)`.
        optimizer: Optimizer class from torch.optim.
        **kwargs: Keyword arguments to pass to the optimizer class.
    Returns:
        Optimizer transform (uqlib.types.Transform instance).
    """
    init_fn = partial(init, optimizer_cls=optimizer, **kwargs)
    update_fn = partial(update, loss_fn=loss_fn)
    return Transform(init_fn, update_fn)
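One behavioural note worth checking against this file: `update` steps the torch.optim optimizer in place, so the tensor passed to `init` is itself mutated (and `inplace=False` raises). A small sketch illustrating this, using only the API above (the quadratic `loss_fn` and the empty batch are placeholders):

```python
import torch
import uqlib

def loss_fn(p, batch):
    return torch.sum(p**2), torch.tensor([])

params = torch.tensor([1.0], requires_grad=True)
transform = uqlib.optim.build(loss_fn, torch.optim.SGD, lr=0.1)
state = transform.init(params)

state = transform.update(state, torch.tensor([]))

# torch.optim steps in place: state.params is the very tensor passed to init.
assert state.params is params
print(params)  # one SGD step: 1.0 - 0.1 * 2.0 = 0.8

# Out-of-place updates are explicitly unsupported here (unlike uqlib.torchopt):
# update(state, batch, inplace=False) raises NotImplementedError.
```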
@@ -0,0 +1,100 @@
from typing import NamedTuple, Any
from functools import partial
import torch
import torchopt

from uqlib.types import TensorTree, Transform, LogProbFn


class TorchOptState(NamedTuple):
    """State of a TorchOpt optimizer.
    Args:
        params: Parameters to be optimised.
        opt_state: TorchOpt optimizer state.
        loss: Loss value.
        aux: Auxiliary information from the loss function call.
    """

    params: TensorTree
    opt_state: TensorTree  # TorchOpt optimizer state (a pytree of tensors)
    loss: torch.Tensor = torch.tensor(0.0)
    aux: Any = None


def init(
    params: TensorTree,
    optimizer: torchopt.base.GradientTransformation,
) -> TorchOptState:
    """Initialise a TorchOpt optimizer.
    Args:
        params: Parameters to be optimised.
        optimizer: TorchOpt functional optimizer.
            Make sure to use the lower-case functional API, e.g. torchopt.adam().
    Returns:
        Initial TorchOptState.
    """
    opt_state = optimizer.init(params)
    return TorchOptState(params, opt_state)


def update(
    state: TorchOptState,
    batch: TensorTree,
    loss_fn: LogProbFn,
    optimizer: torchopt.base.GradientTransformation,
    inplace: bool = True,
) -> TorchOptState:
    """Update the optimizer state.
    Args:
        state: Current state.
        batch: Batch of data.
        loss_fn: Loss function, of the form `loss, aux = loss_fn(params, batch)`.
        optimizer: TorchOpt functional optimizer.
            Make sure to use the lower-case functional API, e.g. torchopt.adam().
        inplace: Whether to update the state in place.
    Returns:
        Updated state.
    """
    params = state.params
    opt_state = state.opt_state
    # Gradients via torch.func, so no autograd graph needs to be retained.
    with torch.no_grad():
        grads, (loss, aux) = torch.func.grad_and_value(loss_fn, has_aux=True)(
            params, batch
        )
    updates, opt_state = optimizer.update(grads, opt_state)
    params = torchopt.apply_updates(params, updates, inplace=inplace)
    return TorchOptState(params, opt_state, loss, aux)


def build(
    loss_fn: LogProbFn,
    optimizer: torchopt.base.GradientTransformation,
) -> Transform:
    """Build a TorchOpt optimizer transformation.
    Example usage:
    ```
    transform = build(loss_fn, torchopt.adam(lr=0.1))
    state = transform.init(params)
    for batch in dataloader:
        state = transform.update(state, batch)
    ```
    Args:
        loss_fn: Loss function, of the form `loss, aux = loss_fn(params, batch)`.
        optimizer: TorchOpt functional optimizer.
            Make sure to use the lower-case functional API, e.g. torchopt.adam().
    Returns:
        TorchOpt optimizer transform (uqlib.types.Transform instance).
    """
    init_fn = partial(init, optimizer=optimizer)
    update_fn = partial(update, optimizer=optimizer, loss_fn=loss_fn)
    return Transform(init_fn, update_fn)
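Unlike the torch.optim wrapper above, this TorchOpt version also supports functional, out-of-place updates. A small sketch of the difference, assuming the `Transform` wrapper forwards keyword arguments such as `inplace` through to `update` (the `loss_fn` and empty batch are placeholders):

```python
import torch
import torchopt
import uqlib

def loss_fn(p, batch):
    return torch.sum(p**2), torch.tensor([])

params = torch.tensor([1.0], requires_grad=True)
transform = uqlib.torchopt.build(loss_fn, torchopt.sgd(lr=0.1))
state = transform.init(params)

# Out-of-place: the original params tensor is left untouched and the
# stepped parameters come back on the new state.
new_state = transform.update(state, torch.tensor([]), inplace=False)
print(params)            # still 1.0
print(new_state.params)  # 0.8 after one SGD step (1.0 - 0.1 * 2.0)
```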