diff --git a/src/causalprog/solvers/iteration_result.py b/src/causalprog/solvers/iteration_result.py
new file mode 100644
index 0000000..ae555f8
--- /dev/null
+++ b/src/causalprog/solvers/iteration_result.py
@@ -0,0 +1,63 @@
+"""Container classes for outputs from each iteration of solver methods."""
+
+from dataclasses import dataclass, field
+
+import numpy.typing as npt
+
+from causalprog.utils.norms import PyTree
+
+
+@dataclass(frozen=False)
+class IterationResult:
+    """
+    Container class storing the state of a solver at iteration `iters`.
+
+    Args:
+        fn_args: Argument to the objective function at the current iteration
+            (the candidate solution).
+        grad_val: Value of the gradient of the objective function at `fn_args`.
+        iters: Number of iterations performed.
+        obj_val: Value of the objective function at `fn_args`.
+        iter_history: List of iteration numbers at which history was logged.
+        fn_args_history: List of `fn_args` at each logged iteration.
+        grad_val_history: List of `grad_val` at each logged iteration.
+        obj_val_history: List of `obj_val` at each logged iteration.
+
+    """
+
+    fn_args: PyTree
+    grad_val: PyTree
+    iters: int
+    obj_val: npt.ArrayLike
+
+    iter_history: list[int] = field(default_factory=list)
+    fn_args_history: list[PyTree] = field(default_factory=list)
+    grad_val_history: list[PyTree] = field(default_factory=list)
+    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
+
+
+def _update_iteration_result(
+    iter_result: IterationResult,
+    current_params: PyTree,
+    gradient_value: PyTree,
+    iters: int,
+    objective_value: npt.ArrayLike,
+    history_logging_interval: int,
+) -> None:
+    """
+    Update the `IterationResult` object with current iteration data.
+
+    Only updates the history if `history_logging_interval` is positive and
+    the current iteration is a multiple of `history_logging_interval`.
+
+    """
+    iter_result.fn_args = current_params
+    iter_result.grad_val = gradient_value
+    iter_result.iters = iters
+    iter_result.obj_val = objective_value
+
+    if history_logging_interval > 0 and iters % history_logging_interval == 0:
+        iter_result.iter_history.append(iters)
+        iter_result.fn_args_history.append(current_params)
+        iter_result.grad_val_history.append(gradient_value)
+        iter_result.obj_val_history.append(objective_value)
diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py
index 141d5e3..08526b9 100644
--- a/src/causalprog/solvers/sgd.py
+++ b/src/causalprog/solvers/sgd.py
@@ -8,6 +8,11 @@
 import numpy.typing as npt
 import optax
 
+from causalprog.solvers.iteration_result import (
+    IterationResult,
+    _update_iteration_result,
+)
+from causalprog.solvers.solver_callbacks import _normalise_callbacks, _run_callbacks
 from causalprog.solvers.solver_result import SolverResult
 from causalprog.utils.norms import PyTree, l2_normsq
 
@@ -23,6 +28,10 @@ def stochastic_gradient_descent(
     maxiter: int = 1000,
     optimiser: optax.GradientTransformationExtraArgs | None = None,
     tolerance: float = 1.0e-8,
+    history_logging_interval: int = -1,
+    callbacks: Callable[[IterationResult], None]
+    | list[Callable[[IterationResult], None]]
+    | None = None,
 ) -> SolverResult:
     """
     Minimise a function of one argument using Stochastic Gradient Descent (SGD).
@@ -65,12 +74,17 @@
             this number of iterations is exceeded.
         optimiser: The `optax` optimiser to use during the update step.
         tolerance: `tolerance` used when determining if a minimum has been found.
+        history_logging_interval: Interval (in number of iterations) at which to log
+            the history of the optimisation. If `history_logging_interval` <= 0, no
+            history is logged.
+        callbacks: A `callable` or list of `callables` that take an
+            `IterationResult` as their only argument and return `None`.
+            These will be called once per iteration of the optimisation
+            procedure, starting at iteration 0.
+
     Returns:
-        Minimising argument of `obj_fn`.
-        Value of `obj_fn` at the minimum.
-        Gradient of `obj_fn` at the minimum.
-        Number of iterations performed.
+        SolverResult: Result of the optimisation procedure.
 
     """
 
     if not fn_args:
@@ -82,21 +96,40 @@
     if not optimiser:
         optimiser = optax.adam(learning_rate)
 
+    callbacks = _normalise_callbacks(callbacks)
+
     def objective(x: npt.ArrayLike) -> npt.ArrayLike:
         return obj_fn(x, *fn_args, **fn_kwargs)
 
     def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         return convergence_criteria(x, dx) < tolerance
 
-    converged = False
+    value_and_grad_fn = jax.jit(jax.value_and_grad(objective))
+    # Initialise solver state
     opt_state = optimiser.init(initial_guess)
     current_params = deepcopy(initial_guess)
-    gradient = jax.grad(objective)
+    converged = False
+    objective_value, gradient_value = value_and_grad_fn(current_params)
+
+    iter_result = IterationResult(
+        fn_args=current_params,
+        grad_val=gradient_value,
+        iters=0,
+        obj_val=objective_value,
+    )
 
     for _ in range(maxiter + 1):
-        objective_value = objective(current_params)
-        gradient_value = gradient(current_params)
+        _update_iteration_result(
+            iter_result,
+            current_params,
+            gradient_value,
+            _,
+            objective_value,
+            history_logging_interval,
+        )
+
+        _run_callbacks(iter_result, callbacks)
 
         if converged := is_converged(objective_value, gradient_value):
             break
 
@@ -104,6 +137,8 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         updates, opt_state = optimiser.update(gradient_value, opt_state)
         current_params = optax.apply_updates(current_params, updates)
+        objective_value, gradient_value = value_and_grad_fn(current_params)
+
     iters_used = _
     reason_msg = (
         f"Did not converge after {iters_used} iterations" if not converged else ""
     )
@@ -117,4 +152,8 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         obj_val=objective_value,
         reason=reason_msg,
         successful=converged,
+        iter_history=iter_result.iter_history,
+        fn_args_history=iter_result.fn_args_history,
+        grad_val_history=iter_result.grad_val_history,
+        obj_val_history=iter_result.obj_val_history,
     )
diff --git a/src/causalprog/solvers/solver_callbacks.py b/src/causalprog/solvers/solver_callbacks.py
new file mode 100644
index 0000000..482d601
--- /dev/null
+++ b/src/causalprog/solvers/solver_callbacks.py
@@ -0,0 +1,61 @@
+"""Module for callback functions for solvers."""
+
+from collections.abc import Callable
+
+from tqdm.auto import tqdm
+
+from causalprog.solvers.iteration_result import IterationResult
+
+
+def _normalise_callbacks(
+    callbacks: Callable[[IterationResult], None]
+    | list[Callable[[IterationResult], None]]
+    | None = None,
+) -> list[Callable[[IterationResult], None]]:
+    if callbacks is None:
+        return []
+    if callable(callbacks):
+        return [callbacks]
+    if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks):
+        return callbacks
+
+    msg = "Callbacks must be a callable or a sequence of callables"
+    raise TypeError(msg)
+
+
+def _run_callbacks(
+    iter_result: IterationResult,
+    callbacks: list[Callable[[IterationResult], None]],
+) -> None:
+    for cb in callbacks:
+        cb(iter_result)
+
+
+def tqdm_callback(total: int) -> Callable[[IterationResult], None]:
+    """
+    Progress bar callback using `tqdm`.
+
+    Creates a callback function that can be passed to solvers to display a progress bar
+    during optimisation. The progress bar updates based on the number of iterations and
+    also displays the current objective value.
+
+    Args:
+        total: Total number of iterations for the progress bar.
+
+    Returns:
+        Callback function that updates the progress bar.
+
+    """
+    bar = tqdm(total=total)
+    last_it = {"i": 0}  # last iteration seen (dict so cb can mutate it)
+
+    def cb(ir: IterationResult) -> None:
+        step = ir.iters - last_it["i"]
+        if step > 0:
+            bar.update(step)
+
+        # Show the current objective value in the progress bar postfix
+        bar.set_postfix(obj=float(ir.obj_val))
+        last_it["i"] = ir.iters
+
+    return cb
diff --git a/src/causalprog/solvers/solver_result.py b/src/causalprog/solvers/solver_result.py
index eb09457..e2fe8f6 100644
--- a/src/causalprog/solvers/solver_result.py
+++ b/src/causalprog/solvers/solver_result.py
@@ -1,6 +1,6 @@
 """Container class for outputs from solver methods."""
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 import numpy.typing as npt
 
@@ -26,6 +26,10 @@ class SolverResult:
         successful: `True` if solver converged, in which case `fn_args` is the
             argument to the objective function at the solution of the problem being
             solved. `False` otherwise.
+        iter_history: List of iteration numbers at which history was logged.
+        fn_args_history: List of `fn_args` at each logged iteration.
+        grad_val_history: List of `grad_val` at each logged iteration.
+        obj_val_history: List of `obj_val` at each logged iteration.
 
     """
 
@@ -36,3 +40,8 @@ class SolverResult:
     obj_val: npt.ArrayLike
     reason: str
     successful: bool
+
+    iter_history: list[int] = field(default_factory=list)
+    fn_args_history: list[PyTree] = field(default_factory=list)
+    grad_val_history: list[PyTree] = field(default_factory=list)
+    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
diff --git a/tests/test_solvers/test_normalise_solver_inputs.py b/tests/test_solvers/test_normalise_solver_inputs.py
new file mode 100644
index 0000000..ca1e32f
--- /dev/null
+++ b/tests/test_solvers/test_normalise_solver_inputs.py
@@ -0,0 +1,29 @@
+import pytest
+
+from causalprog.solvers.iteration_result import IterationResult
+from causalprog.solvers.solver_callbacks import _normalise_callbacks
+
+
+def test_normalise_callbacks() -> None:
+    """Test that callbacks are normalised correctly."""
+
+    def callback(iter_result: IterationResult) -> None:
+        pass
+
+    # Test single callable
+    assert _normalise_callbacks(callback) == [callback]
+
+    # Test sequence of callables
+    assert _normalise_callbacks([callback, callback]) == [callback, callback]
+
+    # Test None
+    assert _normalise_callbacks(None) == []
+
+    # Test empty sequence
+    assert _normalise_callbacks([]) == []
+
+    # Test invalid input
+    with pytest.raises(
+        TypeError, match="Callbacks must be a callable or a sequence of callables"
+    ):
+        _normalise_callbacks(42)  # type: ignore[arg-type]
diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py
index f602ee6..768a194 100644
--- a/tests/test_solvers/test_sgd.py
+++ b/tests/test_solvers/test_sgd.py
@@ -6,6 +6,7 @@
 import numpy.typing as npt
 import pytest
 
+from causalprog.solvers.iteration_result import IterationResult
 from causalprog.solvers.sgd import stochastic_gradient_descent
 from causalprog.utils.norms import PyTree
 
@@ -87,3 +88,260 @@ def test_sgd(
     assert jax.tree_util.tree_all(
         jax.tree_util.tree_map(jax.numpy.allclose, result.fn_args, expected)
     )
+
+
+@pytest.mark.parametrize(
+    (
+        "history_logging_interval",
+        "expected_iters",
+    ),
+    [
+        pytest.param(
+            1,
+            list(range(11)),
+            id="interval=1",
+        ),
+        pytest.param(
+            2,
+            list(range(0, 11, 2)),
+            id="interval=2",
+        ),
+        pytest.param(
+            3,
+            list(range(0, 11, 3)),
+            id="interval=3",
+        ),
+        pytest.param(
+            0,
+            [],
+            id="interval=0 (no logging)",
+        ),
+        pytest.param(
+            -1,
+            [],
+            id="interval=-1 (no logging)",
+        ),
+    ],
+)
+def test_sgd_history_logging_intervals(
+    history_logging_interval: int, expected_iters: list[int]
+) -> None:
+    """Test that history logging intervals work correctly."""
+
+    def obj_fn(x):
+        return (x**2).sum()
+
+    initial_guess = jnp.atleast_1d(1.0)
+
+    result = stochastic_gradient_descent(
+        obj_fn,
+        initial_guess,
+        maxiter=10,
+        tolerance=0.0,
+        history_logging_interval=history_logging_interval,
+    )
+
+    # Check that the correct iterations were logged
+    assert result.iter_history == expected_iters, (
+        f"IterationResult.iter_history logged incorrectly. Expected {expected_iters}."
+        f" Got {result.iter_history}"
+    )
+
+    # Check that the correct number of fn_args, grad_val, obj_val were logged
+    assert len(result.fn_args_history) == len(expected_iters), (
+        "IterationResult.fn_args_history logged incorrectly."
+        f" Expected {len(expected_iters)} entries. Got {len(result.fn_args_history)}"
+    )
+    assert len(result.grad_val_history) == len(expected_iters), (
+        "IterationResult.grad_val_history logged incorrectly."
+        f" Expected {len(expected_iters)} entries. Got {len(result.grad_val_history)}"
+    )
+    assert len(result.obj_val_history) == len(expected_iters), (
+        "IterationResult.obj_val_history logged incorrectly."
+        f" Expected {len(expected_iters)} entries. Got {len(result.obj_val_history)}"
+    )
+
+    # Check that logged fn_args, grad_val, obj_val line up correctly
+    value_and_grad_fn = jax.jit(jax.value_and_grad(obj_fn))
+
+    if len(expected_iters) > 0:
+        for fn_args, obj_val, grad_val in zip(
+            result.fn_args_history,
+            result.obj_val_history,
+            result.grad_val_history,
+            strict=True,
+        ):
+            real_obj_val, real_grad_val = value_and_grad_fn(fn_args)
+
+            # Check that logged obj_val and fn_args line up correctly
+            assert real_obj_val == obj_val, (
+                "Logged obj_val does not match obj_fn evaluated at logged fn_args."
+                f" For fn_args {fn_args}, we expected {obj_fn(fn_args)}, got {obj_val}."
+            )
+
+            # Check that logged gradient and fn_args line up correctly
+            assert real_grad_val == grad_val, (
+                "Logged grad_val does not match gradient of obj_fn evaluated at"
+                f" logged fn_args. For fn_args {fn_args}, we expected"
+                f" {jax.grad(obj_fn)(fn_args)}, got {grad_val}."
+            )
+
+
+@pytest.mark.parametrize(
+    (
+        "make_callbacks",
+        "expected",
+    ),
+    [
+        (
+            lambda cb: cb,
+            [0, 1, 2],
+        ),
+        (
+            lambda cb: [cb],
+            [0, 1, 2],
+        ),
+        (
+            lambda cb: [cb, cb],
+            [0, 0, 1, 1, 2, 2],
+        ),
+        (
+            lambda cb: [],  # noqa: ARG005
+            [],
+        ),
+        (
+            lambda cb: None,  # noqa: ARG005
+            [],
+        ),
+    ],
+    ids=[
+        "single callable",
+        "list of one callable",
+        "list of two callables",
+        "callbacks=[]",
+        "callbacks=None",
+    ],
+)
+def test_sgd_callbacks_invocation(
+    make_callbacks: Callable, expected: list[int]
+) -> None:
+    """Test SGD invokes callbacks correctly for all shapes of callbacks input."""
+
+    def obj_fn(x):
+        return (x**2).sum()
+
+    calls = []
+
+    def callback(iter_result: IterationResult) -> None:
+        calls.append(iter_result.iters)
+
+    callbacks = make_callbacks(callback)
+
+    initial = jnp.atleast_1d(1.0)
+
+    stochastic_gradient_descent(
+        obj_fn,
+        initial,
+        maxiter=2,
+        tolerance=0.0,
+        callbacks=callbacks,
+    )
+
+    assert calls == expected, (
+        f"Callback was not called correctly, got {calls}, expected {expected}"
+    )
+
+
+def test_sgd_invalid_callback() -> None:
+    def obj_fn(x):
+        return (x**2).sum()
+
+    initial = jnp.atleast_1d(1.0)
+
+    with pytest.raises(
+        TypeError, match="Callbacks must be a callable or a sequence of callables"
+    ):
+        stochastic_gradient_descent(
+            obj_fn,
+            initial,
+            maxiter=2,
+            tolerance=0.0,
+            callbacks=42,  # type: ignore[arg-type]
+        )
+
+
+@pytest.mark.parametrize(
+    "history_logging_interval", [0, 1, 2], ids=lambda v: f"hist:{v}"
+)
+@pytest.mark.parametrize(
+    "make_callbacks",
+    [
+        lambda cb: cb,
+        lambda cb: [cb],
+        lambda cb: [cb, cb],
+        lambda cb: [],  # noqa: ARG005
+        lambda cb: None,  # noqa: ARG005
+    ],
+    ids=["callable", "list_1", "list_2", "empty", "none"],
+)
+def test_logging_or_callbacks_do_not_affect_sgd_convergence(
+    history_logging_interval,
+    make_callbacks,
+) -> None:
+    """Test that logging and callbacks don't affect convergence of the SGD solver."""
+    calls = []
+
+    def callback(iter_result: IterationResult) -> None:
+        calls.append(iter_result.iters)
+
+    callbacks = make_callbacks(callback)
+
+    def obj_fn(x):
+        return (x**2).sum()
+
+    initial_guess = jnp.atleast_1d(1.0)
+
+    baseline_result = stochastic_gradient_descent(
+        obj_fn,
+        initial_guess,
+        maxiter=6,
+        tolerance=0.0,
+        history_logging_interval=0,
+    )
+
+    result = stochastic_gradient_descent(
+        obj_fn,
+        initial_guess,
+        maxiter=6,
+        tolerance=0.0,
+        history_logging_interval=history_logging_interval,
+        callbacks=callbacks,
+    )
+
+    baseline_attributes = [
+        baseline_result.fn_args,
+        baseline_result.obj_val,
+        baseline_result.grad_val,
+        baseline_result.iters,
+        baseline_result.successful,
+        baseline_result.reason,
+    ]
+
+    result_attributes = [
+        result.fn_args,
+        result.obj_val,
+        result.grad_val,
+        result.iters,
+        result.successful,
+        result.reason,
+    ]
+
+    for baseline_attr, result_attr in zip(
+        baseline_attributes, result_attributes, strict=True
+    ):
+        assert baseline_attr == result_attr, (
+            "Logging or callbacks changed the convergence behaviour of the"
+            " solver. For history_logging_interval"
+            f" {history_logging_interval}, callbacks {callbacks}, expected"
+            f" {baseline_attributes}, got {result_attributes}"
+        )
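Usage note (illustrative, not part of the patch): a minimal sketch of how the `history_logging_interval` and `callbacks` options added in this diff might be combined, assuming the modules land at the import paths shown above. The objective function and iteration counts below are made up for illustration; the call mirrors how the new tests invoke `stochastic_gradient_descent`.

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent
from causalprog.solvers.solver_callbacks import tqdm_callback


def objective(x):
    # Illustrative quadratic objective; any differentiable function works.
    return (x**2).sum()


maxiter = 100
result = stochastic_gradient_descent(
    objective,
    jnp.atleast_1d(1.0),
    maxiter=maxiter,
    history_logging_interval=10,  # log the state every 10th iteration
    callbacks=tqdm_callback(total=maxiter),  # progress bar from solver_callbacks.py
)

# The logged history lists line up index-for-index.
for it, obj in zip(result.iter_history, result.obj_val_history):
    print(f"iteration {it}: objective {obj}")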