Add callbacks, history logging and performance improvements to sgd. #105
New file: causalprog/solvers/iteration_result.py (path inferred from the import statements later in this diff)

@@ -0,0 +1,63 @@
"""Container classes for outputs from each iteration of solver methods."""

from dataclasses import dataclass, field

import numpy.typing as npt

from causalprog.utils.norms import PyTree


@dataclass(frozen=False)
class IterationResult:
    """
    Container class storing state of solvers at iteration `iters`.

    Args:
        fn_args: Argument to the objective function at the final iteration (the
            solution, if `successful` is `True`).
        grad_val: Value of the gradient of the objective function at `fn_args`.
        iters: Number of iterations performed.
        obj_val: Value of the objective function at `fn_args`.
        iter_history: List of iteration numbers at which history was logged.
        fn_args_history: List of `fn_args` at each logged iteration.
        grad_val_history: List of `grad_val` at each logged iteration.
        obj_val_history: List of `obj_val` at each logged iteration.

    """

    fn_args: PyTree
    grad_val: PyTree
    iters: int
    obj_val: npt.ArrayLike

    iter_history: list[int] = field(default_factory=list)
    fn_args_history: list[PyTree] = field(default_factory=list)
    grad_val_history: list[PyTree] = field(default_factory=list)
    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)


def _update_iteration_result(
    iter_result: IterationResult,
    current_params: PyTree,
    gradient_value: PyTree,
    iters: int,
    objective_value: npt.ArrayLike,
    history_logging_interval: int,
) -> None:
    """
    Update the `IterationResult` object with current iteration data.

    Only updates the history if `history_logging_interval` is positive and
    the current iteration is a multiple of `history_logging_interval`.
    """
    iter_result.fn_args = current_params
    iter_result.grad_val = gradient_value
    iter_result.iters = iters
    iter_result.obj_val = objective_value

    if history_logging_interval > 0 and iters % history_logging_interval == 0:
        iter_result.iter_history.append(iters)
        iter_result.fn_args_history.append(current_params)
        iter_result.grad_val_history.append(gradient_value)
        iter_result.obj_val_history.append(objective_value)
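For orientation (not part of the diff), a minimal sketch of how the interval-based logging behaves, assuming `jax.numpy` arrays are acceptable stand-ins for the solver's PyTrees and that the module path matches the import used later in this PR:

import jax.numpy as jnp

from causalprog.solvers.iteration_result import (
    IterationResult,
    _update_iteration_result,
)

# Illustrative values only; in the solver these come from the optimisation loop.
result = IterationResult(
    fn_args=jnp.zeros(2), grad_val=jnp.zeros(2), iters=0, obj_val=0.0
)

for i in range(1, 6):
    params = jnp.full(2, float(i))
    _update_iteration_result(
        result,
        current_params=params,
        gradient_value=params / 10.0,
        iters=i,
        objective_value=float(i) ** 2,
        history_logging_interval=2,  # history is appended only when i % 2 == 0
    )

print(result.iter_history)  # [2, 4]: only every second iteration is logged
print(result.iters)         # 5: the "latest" fields are overwritten every call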
Stochastic gradient descent solver (existing file; exact path not shown in this extract)

@@ -8,6 +8,11 @@
 import numpy.typing as npt
 import optax

+from causalprog.solvers.iteration_result import (
+    IterationResult,
+    _update_iteration_result,
+)
+from causalprog.solvers.solver_callbacks import _normalise_callbacks, _run_callbacks
 from causalprog.solvers.solver_result import SolverResult
 from causalprog.utils.norms import PyTree, l2_normsq

@@ -23,6 +28,10 @@ def stochastic_gradient_descent(
     maxiter: int = 1000,
     optimiser: optax.GradientTransformationExtraArgs | None = None,
     tolerance: float = 1.0e-8,
+    history_logging_interval: int = -1,
+    callbacks: Callable[[IterationResult], None]
+    | list[Callable[[IterationResult], None]]
+    | None = None,
 ) -> SolverResult:
     """
     Minimise a function of one argument using Stochastic Gradient Descent (SGD).

@@ -65,12 +74,17 @@
             this number of iterations is exceeded.
         optimiser: The `optax` optimiser to use during the update step.
         tolerance: `tolerance` used when determining if a minimum has been found.
+        history_logging_interval: Interval (in number of iterations) at which to log
+            the history of optimisation. If `history_logging_interval <= 0`, no
+            history is logged.
+        callbacks: A `callable` or list of `callables` that take an
+            `IterationResult` as their only argument, and return `None`.
+            These will be called at the end of each iteration of the optimisation
+            procedure.

     Returns:
-        Minimising argument of `obj_fn`.
-        Value of `obj_fn` at the minimum.
-        Gradient of `obj_fn` at the minimum.
-        Number of iterations performed.
+        SolverResult: Result of the optimisation procedure.

     """
     if not fn_args:
@@ -82,28 +96,49 @@
     if not optimiser:
         optimiser = optax.adam(learning_rate)

+    callbacks = _normalise_callbacks(callbacks)
+
     def objective(x: npt.ArrayLike) -> npt.ArrayLike:
         return obj_fn(x, *fn_args, **fn_kwargs)

     def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         return convergence_criteria(x, dx) < tolerance

+    converged = False
+    value_and_grad_fn = jax.jit(jax.value_and_grad(objective))
+
+    # init state
     opt_state = optimiser.init(initial_guess)
     current_params = deepcopy(initial_guess)
-    gradient = jax.grad(objective)
-    converged = False
+    objective_value, gradient_value = value_and_grad_fn(current_params)
+
+    iter_result = IterationResult(
+        fn_args=current_params,
+        grad_val=gradient_value,
+        iters=0,
+        obj_val=objective_value,
+    )

     for _ in range(maxiter + 1):
-        objective_value = objective(current_params)
-        gradient_value = gradient(current_params)
+        _update_iteration_result(
+            iter_result,
+            current_params,
+            gradient_value,
+            _,
+            objective_value,
+            history_logging_interval,
+        )
Comment on lines 122 to +130

Review comment: Per Matt's convention from another PR, since we're using …

Reply: Yeah, because of the …
+
+        _run_callbacks(iter_result, callbacks)
+
         if converged := is_converged(objective_value, gradient_value):
             break

         updates, opt_state = optimiser.update(gradient_value, opt_state)
         current_params = optax.apply_updates(current_params, updates)

+        objective_value, gradient_value = value_and_grad_fn(current_params)
+
     iters_used = _
     reason_msg = (
         f"Did not converge after {iters_used} iterations" if not converged else ""

@@ -117,4 +152,8 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         obj_val=objective_value,
         reason=reason_msg,
         successful=converged,
+        iter_history=iter_result.iter_history,
+        fn_args_history=iter_result.fn_args_history,
+        grad_val_history=iter_result.grad_val_history,
+        obj_val_history=iter_result.obj_val_history,
     )
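As a usage sketch of the new arguments (the positional parameter names of `stochastic_gradient_descent` and the solver's module path are assumptions here, since neither is shown in this extract):

import jax.numpy as jnp

from causalprog.solvers.iteration_result import IterationResult
from causalprog.solvers.sgd import stochastic_gradient_descent  # assumed module path


def report(ir: IterationResult) -> None:
    # User-supplied callback: called at the end of every iteration.
    if ir.iters % 100 == 0:
        print(f"iter {ir.iters}: obj = {float(ir.obj_val):.3e}")


result = stochastic_gradient_descent(
    lambda x: jnp.sum(x**2),       # objective with its minimum at the origin
    jnp.array([1.0, -2.0]),        # initial guess
    history_logging_interval=50,   # keep every 50th iterate on the result
    callbacks=report,              # a bare callable is normalised to [report]
)

# The logged history is carried through to the returned SolverResult.
print(result.successful, len(result.obj_val_history))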
New file: causalprog/solvers/solver_callbacks.py (path inferred from the import statements elsewhere in this diff)

@@ -0,0 +1,61 @@
"""Module for callback functions for solvers."""

from collections.abc import Callable

from tqdm.auto import tqdm
Review comment: Is this tqdm (the package imported above) or some other package I'm not aware of? Either way, if it's not part of the standard library, we should be depending on it in …

Reply: This was just a mistake. I think the question is do we want to keep it in causalprog, and if so, do we want it to be the default behaviour for these loops? We would benefit from some sort of default optimisation monitoring system, just so it's easy to know if something is broken.

from causalprog.solvers.iteration_result import IterationResult


def _normalise_callbacks(
    callbacks: Callable[[IterationResult], None]
    | list[Callable[[IterationResult], None]]
    | None = None,
) -> list[Callable[[IterationResult], None]]:
    if callbacks is None:
        return []
    if callable(callbacks):
        return [callbacks]
    if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks):
Review comment (suggested change): I think … will be more robust; we could be passed a … (In fact, I'd be somewhat tempted to go full Python "try and beg forgiveness" ethos and just do `if all(callable(cb) for cb in callbacks):`, because if we get to this point and … )

Reply: I went with lists because of course all sorts of things like strings and bytes are iterables. But now I think of it, it wouldn't make much sense to have a string of callables 😂. Would there be other weird cases that we've not considered?
        return callbacks

    msg = "Callbacks must be a callable or a sequence of callables"
    raise TypeError(msg)


def _run_callbacks(
    iter_result: IterationResult,
    callbacks: list[Callable[[IterationResult], None]],
) -> None:
    for cb in callbacks:
        cb(iter_result)


def tqdm_callback(total: int) -> Callable[[IterationResult], None]:
    """
    Progress bar callback using `tqdm`.

    Creates a callback function that can be passed to solvers to display a progress
    bar during optimisation. The progress bar updates based on the number of
    iterations and also displays the current objective value.

    Args:
        total: Total number of iterations for the progress bar.

    Returns:
        Callback function that updates the progress bar.

    """
    bar = tqdm(total=total)
    last_it = {"i": 0}

    def cb(ir: IterationResult) -> None:
        step = ir.iters - last_it["i"]
        if step > 0:
            bar.update(step)

        # Show the current objective value in the bar's postfix
        bar.set_postfix(obj=float(ir.obj_val))
        last_it["i"] = ir.iters

    return cb
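A possible way to attach the progress-bar callback to the solver; the solver's module path is an assumption, as above, and the objective is assumed to be scalar-valued since the callback calls `float(ir.obj_val)`:

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent  # assumed module path
from causalprog.solvers.solver_callbacks import tqdm_callback

maxiter = 500
result = stochastic_gradient_descent(
    lambda x: jnp.sum((x - 3.0) ** 2),
    jnp.zeros(3),
    maxiter=maxiter,
    callbacks=[tqdm_callback(total=maxiter)],  # bar advances by the change in ir.iters
)

Because the bar is created when `tqdm_callback(total=...)` is called, a callback instance holds state and should not be reused across solver runs.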
Existing file: causalprog/solvers/solver_result.py (path inferred from the import statements elsewhere in this diff)

@@ -1,6 +1,6 @@
 """Container class for outputs from solver methods."""

-from dataclasses import dataclass
+from dataclasses import dataclass, field

 import numpy.typing as npt

@@ -26,6 +26,10 @@ class SolverResult:
         successful: `True` if solver converged, in which case `fn_args` is the
             argument to the objective function at the solution of the problem being
             solved. `False` otherwise.
+        iter_history: List of iteration numbers at which history was logged.
+        fn_args_history: List of `fn_args` at each logged iteration.
+        grad_val_history: List of `grad_val` at each logged iteration.
+        obj_val_history: List of `obj_val` at each logged iteration.

     """

@@ -36,3 +40,8 @@ class SolverResult:
     obj_val: npt.ArrayLike
     reason: str
     successful: bool
+
+    iter_history: list[int] = field(default_factory=list)
+    fn_args_history: list[PyTree] = field(default_factory=list)
+    grad_val_history: list[PyTree] = field(default_factory=list)
+    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
Comment on lines +43 to +47

Review comment: Should we just subclass here, as … ? If we don't want to subclass, then rather than adding all these attributes directly to the class, just add another attribute like …

Reply: Two things I'm thinking about here: …

To summarise, I think we definitely want … but that's just my opinion. I might be overlooking something that makes the above non-issues. Writing this out has also made me realise the need to make the history attributes mutable by the optimisation loop only, and not by callbacks.
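To make the reviewer's "single history attribute" suggestion concrete, one hypothetical shape (names invented for illustration, not part of the PR) might be:

# Hypothetical alternative: group the logged history into one container instead of
# four parallel list attributes on SolverResult. Not part of this PR.
from dataclasses import dataclass, field

import numpy.typing as npt

from causalprog.utils.norms import PyTree


@dataclass(frozen=True)
class History:
    iters: list[int] = field(default_factory=list)
    fn_args: list[PyTree] = field(default_factory=list)
    grad_vals: list[PyTree] = field(default_factory=list)
    obj_vals: list[npt.ArrayLike] = field(default_factory=list)


@dataclass(frozen=True)
class SolverResultWithHistory:  # illustrative name only
    obj_val: npt.ArrayLike
    reason: str
    successful: bool
    history: History = field(default_factory=History)

Grouping the four parallel lists also gives a single place to control mutability, which touches on the point above about letting only the optimisation loop append to the history.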
New test file (exact path not shown in this extract)

@@ -0,0 +1,29 @@
import pytest

from causalprog.solvers.iteration_result import IterationResult
from causalprog.solvers.solver_callbacks import _normalise_callbacks


def test_normalise_callbacks() -> None:
    """Test that callbacks are normalised correctly."""

    def callback(iter_result: IterationResult) -> None:
        pass

    # Test single callable
    assert _normalise_callbacks(callback) == [callback]

    # Test sequence of callables
    assert _normalise_callbacks([callback, callback]) == [callback, callback]

    # Test None
    assert _normalise_callbacks(None) == []

    # Test empty sequence
    assert _normalise_callbacks([]) == []

    # Test invalid input
    with pytest.raises(
        TypeError, match="Callbacks must be a callable or a sequence of callables"
    ):
        _normalise_callbacks(42)  # type: ignore[arg-type]
Comment on lines +7 to +29

Review comment: All good tests. Just flagging that we're favouring the parametrised approach rather than nesting multiple cases (in the event the setup is cheap). We also have a … So would suggest something like:

def placeholder_callback(_: IterationResult) -> None:
    pass


@pytest.mark.parametrize(
    ("input", "expected"),
    [
        pytest.param([], [], id="Empty list"),
        pytest.param(placeholder_callback, [placeholder_callback], id="Single callable"),
        pytest.param(42, TypeError("Callbacks...."), id="..."),
        # Add other test cases as necessary
    ],
)
def test_normalise_callbacks(input, expected) -> None:
    """Test that callbacks are normalised correctly."""
    if isinstance(expected, Exception):
        with raises_context(expected):
            _normalise_callbacks(input)
    else:
        assert _normalise_callbacks(input) == expected

(with typehints etc added as needed).
Review comment: Any reason why this isn't a method of the `IterationResult` class? I expected this to be something like `IterationResult.update` (and it's currently written in that form too -> swap `iter_result` to `self`).

Related, any reason why `history_logging_interval` is an argument that we pass in, rather than an attribute that's set at creation time? (I guess we could want dynamic logging, which we wouldn't get with a fixed attribute, but is that a common enough use-case to design around?) It also means that we could just check `history_logging_interval > 0` once at creation time, and not do it every time in the method.

Reply: Nope. It should be.

No. This is a good point.
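For illustration, a rough sketch (hypothetical, not part of this PR) of the refactor discussed above: fold `_update_iteration_result` into the class as `IterationResult.update`, with `history_logging_interval` fixed at creation time.

# Hypothetical refactor sketched from the review discussion above; not in the PR.
from dataclasses import dataclass, field

import numpy.typing as npt

from causalprog.utils.norms import PyTree


@dataclass(frozen=False)
class IterationResult:
    fn_args: PyTree
    grad_val: PyTree
    iters: int
    obj_val: npt.ArrayLike
    history_logging_interval: int = -1  # now decided once, at creation time

    iter_history: list[int] = field(default_factory=list)
    fn_args_history: list[PyTree] = field(default_factory=list)
    grad_val_history: list[PyTree] = field(default_factory=list)
    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)

    def update(
        self,
        current_params: PyTree,
        gradient_value: PyTree,
        iters: int,
        objective_value: npt.ArrayLike,
    ) -> None:
        """Record the latest iterate and, when due, append it to the history."""
        self.fn_args = current_params
        self.grad_val = gradient_value
        self.iters = iters
        self.obj_val = objective_value

        if (
            self.history_logging_interval > 0
            and iters % self.history_logging_interval == 0
        ):
            self.iter_history.append(iters)
            self.fn_args_history.append(current_params)
            self.grad_val_history.append(gradient_value)
            self.obj_val_history.append(objective_value)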