63 changes: 63 additions & 0 deletions src/causalprog/solvers/iteration_result.py
@@ -0,0 +1,63 @@
"""Container classes for outputs from each iteration of solver methods."""

from dataclasses import dataclass, field

import numpy.typing as npt

from causalprog.utils.norms import PyTree


@dataclass(frozen=False)
class IterationResult:
"""
Container class storing state of solvers at iteration `iters`.
Args:
fn_args: Argument to the objective function at the final iteration (the solution,
if `successful` is `True`).
grad_val: Value of the gradient of the objective function at `fn_args`.
iters: Number of iterations performed.
obj_val: Value of the objective function at `fn_args`.
iter_history: List of iteration numbers at which history was logged.
fn_args_history: List of `fn_args` at each logged iteration.
grad_val_history: List of `grad_val` at each logged iteration.
obj_val_history: List of `obj_val` at each logged iteration.
"""

fn_args: PyTree
grad_val: PyTree
iters: int
obj_val: npt.ArrayLike

iter_history: list[int] = field(default_factory=list)
fn_args_history: list[PyTree] = field(default_factory=list)
grad_val_history: list[PyTree] = field(default_factory=list)
obj_val_history: list[npt.ArrayLike] = field(default_factory=list)


def _update_iteration_result(
Collaborator:

Any reason why this isn't a method of the IterationResult class? I expected this to be something like IterationResult.update (and it's currently written in that form too -> swap iter_result to self).

Related, any reason why history_logging_interval is an argument that we pass in, rather than an attribute that's set at creation time? (I guess we could want dynamic logging, which we wouldn't get with a fixed attribute, but is that a common enough use-case to design around?) It also means that we could just check history_logging_interval > 0 once at creation time, and not do it every time in the method.

Collaborator Author:

> Any reason why this isn't a method of the IterationResult class? I expected this to be something like IterationResult.update (and it's currently written in that form too -> swap iter_result to self).

Nope. It should be.

> Related, any reason why history_logging_interval is an argument that we pass in, rather than an attribute that's set at creation time? (I guess we could want dynamic logging, which we wouldn't get with a fixed attribute, but is that a common enough use-case to design around?) It also means that we could just check history_logging_interval > 0 once at creation time, and not do it every time in the method.

No. This is a good point.

iter_result: IterationResult,
current_params: PyTree,
gradient_value: PyTree,
iters: int,
objective_value: npt.ArrayLike,
history_logging_interval: int,
) -> None:
"""
Update the `IterationResult` object with current iteration data.
Only updates the history if `history_logging_interval` is positive and
the current iteration is a multiple of `history_logging_interval`.
"""
iter_result.fn_args = current_params
iter_result.grad_val = gradient_value
iter_result.iters = iters
iter_result.obj_val = objective_value

if history_logging_interval > 0 and iters % history_logging_interval == 0:
iter_result.iter_history.append(iters)
iter_result.fn_args_history.append(current_params)
iter_result.grad_val_history.append(gradient_value)
iter_result.obj_val_history.append(objective_value)
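
As a reference for the review thread above, here is a minimal sketch of the suggested refactor: update becomes a method of IterationResult, and history_logging_interval becomes a field validated once at creation time. This is illustrative only (it is not part of the diff), and the extra _log_history attribute is an assumption of the sketch.

from dataclasses import dataclass, field

import numpy.typing as npt

from causalprog.utils.norms import PyTree


@dataclass(frozen=False)
class IterationResult:
    """Sketch: owns its update logic; logging interval fixed at creation."""

    fn_args: PyTree
    grad_val: PyTree
    iters: int
    obj_val: npt.ArrayLike
    history_logging_interval: int = -1

    iter_history: list[int] = field(default_factory=list)
    fn_args_history: list[PyTree] = field(default_factory=list)
    grad_val_history: list[PyTree] = field(default_factory=list)
    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Decide once, at creation time, whether history will ever be logged.
        self._log_history = self.history_logging_interval > 0

    def update(
        self,
        current_params: PyTree,
        gradient_value: PyTree,
        iters: int,
        objective_value: npt.ArrayLike,
    ) -> None:
        """Record the latest iteration, logging history at the configured interval."""
        self.fn_args = current_params
        self.grad_val = gradient_value
        self.iters = iters
        self.obj_val = objective_value

        if self._log_history and iters % self.history_logging_interval == 0:
            self.iter_history.append(iters)
            self.fn_args_history.append(current_params)
            self.grad_val_history.append(gradient_value)
            self.obj_val_history.append(objective_value)
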
55 changes: 47 additions & 8 deletions src/causalprog/solvers/sgd.py
@@ -8,6 +8,11 @@
import numpy.typing as npt
import optax

from causalprog.solvers.iteration_result import (
IterationResult,
_update_iteration_result,
)
from causalprog.solvers.solver_callbacks import _normalise_callbacks, _run_callbacks
from causalprog.solvers.solver_result import SolverResult
from causalprog.utils.norms import PyTree, l2_normsq

@@ -23,6 +28,10 @@ def stochastic_gradient_descent(
maxiter: int = 1000,
optimiser: optax.GradientTransformationExtraArgs | None = None,
tolerance: float = 1.0e-8,
history_logging_interval: int = -1,
callbacks: Callable[[IterationResult], None]
| list[Callable[[IterationResult], None]]
| None = None,
) -> SolverResult:
"""
Minimise a function of one argument using Stochastic Gradient Descent (SGD).
@@ -65,12 +74,17 @@
this number of iterations is exceeded.
optimiser: The `optax` optimiser to use during the update step.
tolerance: `tolerance` used when determining if a minimum has been found.
history_logging_interval: Interval (in number of iterations) at which to log
the history of optimisation. If history_logging_interval <= 0, no
history is logged.
callbacks: A `callable` or list of `callables` that take an
`IterationResult` as their only argument, and return `None`.
These will be called at the end of each iteration of the optimisation
procedure.


Returns:
Minimising argument of `obj_fn`.
Value of `obj_fn` at the minimum.
Gradient of `obj_fn` at the minimum.
Number of iterations performed.
SolverResult: Result of the optimisation procedure.

"""
if not fn_args:
@@ -82,28 +96,49 @@
if not optimiser:
optimiser = optax.adam(learning_rate)

callbacks = _normalise_callbacks(callbacks)

def objective(x: npt.ArrayLike) -> npt.ArrayLike:
return obj_fn(x, *fn_args, **fn_kwargs)

def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
return convergence_criteria(x, dx) < tolerance

converged = False
value_and_grad_fn = jax.jit(jax.value_and_grad(objective))

# init state
opt_state = optimiser.init(initial_guess)
current_params = deepcopy(initial_guess)
gradient = jax.grad(objective)
converged = False
objective_value, gradient_value = value_and_grad_fn(current_params)

iter_result = IterationResult(
fn_args=current_params,
grad_val=gradient_value,
iters=0,
obj_val=objective_value,
)

for _ in range(maxiter + 1):
objective_value = objective(current_params)
gradient_value = gradient(current_params)
_update_iteration_result(
iter_result,
current_params,
gradient_value,
_,
objective_value,
history_logging_interval,
)
Comment on lines 122 to +130
Collaborator:

Per Matt's convention from another PR, since we're using _ in the loop, we should probably use a name like current_iter or something for the loop variable. (I would add a suggestion, sorry, but GitHub doesn't let me suggest things for unchanged lines.)

Collaborator Author:

Yeah, because of the iters_used = _ line, I wasn't sure if this was a convention you were using, so I didn't want to change it without checking.


_run_callbacks(iter_result, callbacks)

if converged := is_converged(objective_value, gradient_value):
break

updates, opt_state = optimiser.update(gradient_value, opt_state)
current_params = optax.apply_updates(current_params, updates)

objective_value, gradient_value = value_and_grad_fn(current_params)

iters_used = _
reason_msg = (
f"Did not converge after {iters_used} iterations" if not converged else ""
@@ -117,4 +152,8 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
obj_val=objective_value,
reason=reason_msg,
successful=converged,
iter_history=iter_result.iter_history,
fn_args_history=iter_result.fn_args_history,
grad_val_history=iter_result.grad_val_history,
obj_val_history=iter_result.obj_val_history,
)
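
As a usage illustration of the new history-logging option (not part of the diff), a call might look like the sketch below. The quadratic objective, the hyperparameter values, and the assumption that obj_fn, initial_guess and learning_rate can be passed as keywords are all illustrative; the full signature is not shown in this diff.

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent


def quadratic(x):
    # Simple convex objective with its minimum at x = 3.
    return jnp.sum((x - 3.0) ** 2)


result = stochastic_gradient_descent(
    obj_fn=quadratic,
    initial_guess=jnp.array([0.0]),
    learning_rate=0.1,
    maxiter=500,
    history_logging_interval=50,  # record state every 50 iterations
)

# Histories are copied onto the returned SolverResult.
print(result.iter_history)     # e.g. [0, 50, 100, ...]
print(result.obj_val_history)  # objective value at each logged iteration
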
61 changes: 61 additions & 0 deletions src/causalprog/solvers/solver_callbacks.py
@@ -0,0 +1,61 @@
"""Module for callback functions for solvers."""

from collections.abc import Callable

from tqdm.auto import tqdm
Collaborator:

Is this the tqdm package, or some other package I'm not aware of? Either way, if it's not part of the standard library, we should be depending on it in pyproject.toml.

Collaborator Author (@samjmolyneux, Oct 7, 2025):

This was just a mistake.
I use it in the example notebook. Originally I defined it in the notebook along with a pip install there, but then refactored it into causalprog because I didn't want it to add confusion. I then just forgot to add it to the pyproject.toml.

I think the question is do we want to keep it in causalprog, and if so do we want it to be the default behaviour for these loops?

We would benefit from some sort of default optimisation monitoring system, just so it's easy to know if something is broken.
It can be a bit concerning to wait a few minutes for an optimiser without any feedback from the system.


from causalprog.solvers.iteration_result import IterationResult


def _normalise_callbacks(
callbacks: Callable[[IterationResult], None]
| list[Callable[[IterationResult], None]]
| None = None,
) -> list[Callable[[IterationResult], None]]:
if callbacks is None:
return []
if callable(callbacks):
return [callbacks]
if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks):
Collaborator:

Suggested change
- if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks):
+ if isinstance(callbacks, Iterable) and all(callable(cb) for cb in callbacks):

I think this will be more robust - we could be passed a tuple or any other ordered container and things would still be fine.

(In fact, I'd be somewhat tempted to go full Python "try and beg forgiveness" ethos and just do

    if all(callable(cb) for cb in callbacks):

because if we get to this point, and callbacks isn't an iterable, we'll get an error anyway. But what you have is probably better, since it gives the more general explanation "callbacks could be a list or just one callable")

Collaborator Author:

I went with lists because, of course, all sorts of things like strings and bytes are iterables. But now that I think of it, it wouldn't make much sense to have a string of callables 😂.

Would there be other weird cases that we've not considered?
I suppose in theory, someone could pass a dictionary where all the keys are callables and this would pass the checks.
Maybe we could restrict to tuples and lists or something? Or maybe it is best to just do iterables without type checks and beg forgiveness like you say. I don't know

return callbacks

msg = "Callbacks must be a callable or a sequence of callables"
raise TypeError(msg)
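
To make the trade-off discussed in the thread above concrete, a sketch of the more permissive, Iterable-based variant might look like the following. The function name is hypothetical, and excluding str/bytes explicitly is an assumption of the sketch (their elements can never be callables anyway).

from collections.abc import Callable, Iterable

from causalprog.solvers.iteration_result import IterationResult


def _normalise_callbacks_permissive(
    callbacks: Callable[[IterationResult], None]
    | Iterable[Callable[[IterationResult], None]]
    | None = None,
) -> list[Callable[[IterationResult], None]]:
    if callbacks is None:
        return []
    if callable(callbacks):
        return [callbacks]
    # Accept any iterable of callables, but rule out str/bytes explicitly.
    if isinstance(callbacks, Iterable) and not isinstance(callbacks, (str, bytes)):
        callbacks = list(callbacks)
        if all(callable(cb) for cb in callbacks):
            return callbacks
    msg = "Callbacks must be a callable or an iterable of callables"
    raise TypeError(msg)
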


def _run_callbacks(
iter_result: IterationResult,
callbacks: list[Callable[[IterationResult], None]],
) -> None:
for cb in callbacks:
cb(iter_result)


def tqdm_callback(total: int) -> Callable[[IterationResult], None]:
"""
Progress bar callback using `tqdm`.
Creates a callback function that can be passed to solvers to display a progress bar
during optimisation. The progress bar updates based on the number of iterations and
also displays the current objective value.
Args:
total: Total number of iterations for the progress bar.
Returns:
Callback function that updates the progress bar.
"""
bar = tqdm(total=total)
last_it = {"i": 0}

def cb(ir: IterationResult) -> None:
step = ir.iters - last_it["i"]
if step > 0:
bar.update(step)

# Show the current objective value alongside the bar
bar.set_postfix(obj=float(ir.obj_val))
last_it["i"] = ir.iters

return cb
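
For context on how this callback would be wired into the solver, a usage sketch follows; the objective, hyperparameters, and keyword-argument layout are assumptions made for illustration rather than the confirmed API.

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent
from causalprog.solvers.solver_callbacks import tqdm_callback

maxiter = 200

result = stochastic_gradient_descent(
    obj_fn=lambda x: jnp.sum(x**2),  # placeholder objective
    initial_guess=jnp.array([5.0]),
    learning_rate=0.05,
    maxiter=maxiter,
    # tqdm_callback returns a closure that receives the IterationResult each
    # iteration and advances the progress bar by the iteration delta.
    callbacks=tqdm_callback(total=maxiter),
)
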
11 changes: 10 additions & 1 deletion src/causalprog/solvers/solver_result.py
@@ -1,6 +1,6 @@
"""Container class for outputs from solver methods."""

from dataclasses import dataclass
from dataclasses import dataclass, field

import numpy.typing as npt

@@ -26,6 +26,10 @@ class SolverResult:
successful: `True` if solver converged, in which case `fn_args` is the
argument to the objective function at the solution of the problem being
solved. `False` otherwise.
iter_history: List of iteration numbers at which history was logged.
fn_args_history: List of `fn_args` at each logged iteration.
grad_val_history: List of `grad_val` at each logged iteration.
obj_val_history: List of `obj_val` at each logged iteration.

"""

@@ -36,3 +40,8 @@
obj_val: npt.ArrayLike
reason: str
successful: bool

iter_history: list[int] = field(default_factory=list)
fn_args_history: list[PyTree] = field(default_factory=list)
grad_val_history: list[PyTree] = field(default_factory=list)
obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
Comment on lines +43 to +47
Collaborator (@willGraham01, Oct 6, 2025):

Should we just subclass here as class SolverResult(IterationResult)? We can drop the frozen=True and trust that the user isn't going to overwrite our results.

If we don't want to subclass, then rather than adding all these attributes directly to the class, could we just add another attribute like iteration_log and assign its value to the IterationResult instance created from running the solver?

Collaborator Author (@samjmolyneux, Oct 7, 2025):

Two things I'm thinking about here:

  1. I like SolverResult being frozen. Packages like these are used largely by Jupyter notebook users. It's very easy to reassign an attribute, delete the cell you reassigned it in, and then not realise that you have changed something. I think leaving SolverResult unfrozen will be a source of hard-to-detect problems for the end user.

  2. IterationResult is being passed by reference to user-written callbacks, meaning a callback can directly change IterationResult and thus convergence behaviour. I was unsure whether this would be a desirable trait for causalprog.
    I personally like it, because the change in convergence behaviour can be intentional and useful, such as adding a callback that clips the parameters into some feasible set. On the other hand, a user may accidentally reassign the params and break everything.
    Either way, I think 2 is something to raise an issue and think about later, but it certainly highlights a possible difference in direction for the two classes that might make it easier to keep them separated for now.

To summarise, I think we definitely want SolverResult to be frozen. I don't know what we want to do with IterationResult, so it seems easiest to keep it separate from SolverResult for now and not commit to anything until we need to.

But that's just my opinion. I might be overlooking something that makes the above non-issues.

Writing this out has also made me realise the need to make the history attributes mutable by the optimisation loop only and not by callbacks.
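
A sketch of the iteration_log alternative raised in the thread above follows; the class name is hypothetical, and only a subset of SolverResult's existing fields is repeated, for brevity.

from dataclasses import dataclass

import numpy.typing as npt

from causalprog.solvers.iteration_result import IterationResult
from causalprog.utils.norms import PyTree


@dataclass(frozen=True)
class SolverResultWithLog:
    """Sketch: frozen result that carries the solver's IterationResult wholesale."""

    fn_args: PyTree
    obj_val: npt.ArrayLike
    reason: str
    successful: bool

    # The full iteration record produced while the solver ran. The result
    # container stays frozen, so top-level attributes cannot be reassigned,
    # while the nested log keeps the per-iteration history.
    iteration_log: IterationResult | None = None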

29 changes: 29 additions & 0 deletions tests/test_solvers/test_normalise_solver_inputs.py
@@ -0,0 +1,29 @@
import pytest

from causalprog.solvers.iteration_result import IterationResult
from causalprog.solvers.solver_callbacks import _normalise_callbacks


def test_normalise_callbacks() -> None:
"""Test that callbacks are normalised correctly."""

def callback(iter_result: IterationResult) -> None:
pass

# Test single callable
assert _normalise_callbacks(callback) == [callback]

# Test sequence of callables
assert _normalise_callbacks([callback, callback]) == [callback, callback]

# Test None
assert _normalise_callbacks(None) == []

# Test empty sequence
assert _normalise_callbacks([]) == []

# Test invalid input
with pytest.raises(
TypeError, match="Callbacks must be a callable or a sequence of callables"
):
_normalise_callbacks(42) # type: ignore[arg-type]
Comment on lines +7 to +29
Collaborator:

All good tests. Just flagging that we're favouring the parametrised approach rather than nesting multiple cases (in the event the setup is cheap).

We also have a raises_error fixture because that pytest.raises syntax is very common in our tests 😅

So I would suggest something like

def placeholder_callback(_: IterationResult) -> None:
    pass

@pytest.mark.parametrize(
    ("input", "expected"),
    [
        pytest.param([], [], id="Empty list"),
        pytest.param(placeholder_callback, [placeholder_callback], id="Single callable"),
        pytest.param(42, TypeError("Callbacks...."), id="..."),
        # Add other test cases as necessary
    ]
)
def test_normalise_callbacks(input, expected, raises_error) -> None:
    """Test that callbacks are normalised correctly."""

    if isinstance(expected, Exception):
        with raises_error(expected):
            _normalise_callbacks(input)
    else:
        assert _normalise_callbacks(input) == expected

(with typehints etc added as needed).
