Skip to content

Conversation

samjmolyneux
Copy link
Collaborator

  • By using jax.jit, and calculating the gradient and objective in a single pass with jax.value_and_grad, we get big speed boosts to sgd. The new time taken is approximately $\frac{1}{7}\text{th}$ of the original for the two normal example on my machine.

  • New history_logging_interval parameter for the stochastic_gradient_descent function allows the user to enable or disable logging of the optimisation history. The interval determines how frequently the history is logged. This makes it easier to debug optimisations and make decisions about hyperparameters.

  • To make history logging work, we add a dataclass IterationResult which is analogous to SolverResult. However, IterationResult uses frozen=True to allow for dataclass updates each iteration. Using a dataclass ensures backward compatibility for callbacks if a new attribute is logged.

  • New callbacks parameter for the stochastic_gradient_descent function allows the user to set a list of callback functions as is standard in optimisation loops. In future, the callbacks can be used for early stopping or live plotting of the results. As an example, we include a useful tqdm callback that displays a progress bar for the iterations and displays the current objective value.

  • We also add the following tests:
    • test_normalise_callbacks: Tests that _normalise_callbacks does validation and casts valid types to list[Callable[IterationResult], None] .
    • test_sgd_history_logging_intervals: Tests that the correct iterations are logged for different intervals and that the correct associated obj, fn_args and grad are too for sgd.
    • test_callback_invocation: Tests that sgd callbacks are called in the correct order with the correct IterationResult.
    • test_invalid_callback: Tests that sgd will raise an error if given an invalid callback.
    • test_logging_or_callbacks_affect_sgd_convergence: Tests that various combinations of callbacks and logging intervals all result in the same convergence behaviour and thus all have the same final obj, fn_args, grad etc.

To sgd.py:
    - Added single pass JIT grad and objective calculation for big speed
    - boosts.
    - Added history logging option for easy debugging and better understanding
    of optimisation process.
    - Added callbacks functionality, allowing for user specific callbacks,
    e.g. early stopping, live graph etc.

Added iteration_result.py to monitor current state of convergence.
Use of dataclass ensures backwards compatablity of callbacks.

Added solver_callbacks.py for callbacks and associated funcs.

Added history attributes to solver result.
@samjmolyneux samjmolyneux force-pushed the sjmolyneux/sgd-improvements branch from 62924d7 to 3fbed8d Compare October 5, 2025 21:30
@willGraham01 willGraham01 self-requested a review October 6, 2025 09:26
Copy link
Collaborator

@willGraham01 willGraham01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All-in-all I like these changes, and these are useful features that we could do with adding. Callbacks in particular should be very useful from a development/debugging perspective too,

Most of my comments relate to the design decisions for the code, considering what's in the rest of the codebase. Namely I think we can do some code recycling in places, and we tend to write our tests in a particular format (though the test cases provided are good).

Also, there are only two commits on this branch (one for codebase changes, one for tests). In general, don't be afraid to use more granular commits in your PRs (just take a look at how long the other PRs are!) - we use squash merges anyway, so everything gets condensed into a single commit on main anyway. And it's good to be able to roll things back.

obj_val_history: list[npt.ArrayLike] = field(default_factory=list)


def _update_iteration_result(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why this isn't a method of the IterationResult class? I expected this to be something like IterationResult.update (and it's currently written in that form too -> swap iter_result to self).

Related, any reason for why history_logging_interval is an argument that we pass in, rather than an attribute that's set at creation time (I guess we could want dynamic logging which we wouldn't get with a fixed attribute, but is that a common enough use-case to design around?). It also means that we could just check history_logging_interval > 0 once at creation time, and not do it every time in the method.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why this isn't a method of the IterationResult class? I expected this to be something like IterationResult.update (and it's currently written in that form too -> swap iter_result to self).

Nope. It should be.

Related, any reason for why history_logging_interval is an argument that we pass in, rather than an attribute that's set at creation time (I guess we could want dynamic logging which we wouldn't get with a fixed attribute, but is that a common enough use-case to design around?). It also means that we could just check history_logging_interval > 0 once at creation time, and not do it every time in the method.

No. This is a good point.

Comment on lines 122 to +130
for _ in range(maxiter + 1):
objective_value = objective(current_params)
gradient_value = gradient(current_params)
_update_iteration_result(
iter_result,
current_params,
gradient_value,
_,
objective_value,
history_logging_interval,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per Matt's convention from another PR, since we're using _ in the loop, we should probably use a name like current_iter or something for the loop variable. (Would add a suggestion sorry but GitHub doesn't let me suggest things for unchanged lines)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, because of the iters_used = _ line, I wasn't sure if this was a convention you were using, so I didn't want to change it without checking.

Comment on lines +43 to +47

iter_history: list[int] = field(default_factory=list)
fn_args_history: list[PyTree] = field(default_factory=list)
grad_val_history: list[PyTree] = field(default_factory=list)
obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
Copy link
Collaborator

@willGraham01 willGraham01 Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just subclass here as class SolverResult(IterationResult)? We can drop the frozen=True and trust that the user isn't going to overwrite our results.

If we don't want to subclass; rather than adding all these attributes directly to the class, just add another attribute like iteration_log and assign its value as the IterationResult instance created from running the solver?

Copy link
Collaborator Author

@samjmolyneux samjmolyneux Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two things I'm thinking about here:

  1. I like SolverResult being frozen. Packages like these are used largely by Jupyter notebook users. It's very easy to reassign an attribute and delete the cell you reassigned it in and then not realise that you have changed something. I think leaving SolverResult unfrozen will be a source of hard to detect problems for the end user.

  2. IterationResult is being passed by reference to user written callbacks. Meaning a callback can directly change IterationResult and thus convergence behaviour. I was unsure whether this would be a desirable trait for causalprog.
    I personally like it, because the change in convergence behaviour can be intentional and useful, such as adding a callback that clips that parameters into some feasible set. On the other, a user may accidentally reassign the params and break everything.
    Either way, I think 2 is something to raise an issue and think about later, but it certainly highlights a possible difference in direction for the two classes that might make it easier to keep them separated for now.

To summarise, I think we definitely want SolverResult to be frozen. I don't know what we want to do with IterationResult, so it seems easiest to keep it separate from SolverResult for now and not commit to anything until we need to.

But that's just my opinion. I might be overlooking something that makes the above non-issues.

Writing this out has also made me realise the need to make the history attributes mutable by the optimisation loop only and not by callbacks.


from collections.abc import Callable

from tqdm.auto import tqdm
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this this or some other package I'm not aware of? Either way, if it's not part of the standard library, we should be depending on it in pyproject.toml.

Copy link
Collaborator Author

@samjmolyneux samjmolyneux Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just a mistake.
I use it in the example notebook. Originally I defined it in the notebook along with a pip install there, but then refactored it into causalprog because I didn't want it to add confusion. I then just forgot to add it to the pyproject.toml.

I think the question is do we want to keep it in causalprog, and if so do we want it to be the default behaviour for these loops?

We would benefit from some sort of default optimisation monitoring system, just so it's easy to know if something is broke.
It can be a bit concerning waiting a few minutes for an optimiser without any feedback from the system.

return []
if callable(callbacks):
return [callbacks]
if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks):
if isinstance(callbacks, Iterable) and all(callable(cb) for cb in callbacks):

I think will be more robust - we could be passed a tuple or any other ordered container and things would still be fine.

(In fact, I'd be somewhat tempted to go full Python "try and beg forgiveness" ethos and just do

    if all(callable(cb) for cb in callbacks):

because if we get to this point, and callbacks isn't an iterable, we'll get an error anyway. But what you have is probably better, since it gives the more general explanation "callbacks could be a list or just one callable")

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with lists because of course all sorts of things like strings and bytes are iterables. But now I think of it, it wouldn't make much sense to have a string of callables 😂.

Would there be other weird cases that we've not considered?
I suppose in theory, someone could pass a dictionary where all the keys are callables and this would pass the checks.
Maybe we could restrict to tuples and lists or something? Or maybe it is best to just do iterables without type checks and beg forgiveness like you say. I don't know

Comment on lines +7 to +29
def test_normalise_callbacks() -> None:
"""Test that callbacks are normalised correctly."""

def callback(iter_result: IterationResult) -> None:
pass

# Test single callable
assert _normalise_callbacks(callback) == [callback]

# Test sequence of callables
assert _normalise_callbacks([callback, callback]) == [callback, callback]

# Test None
assert _normalise_callbacks(None) == []

# Test empty sequence
assert _normalise_callbacks([]) == []

# Test invalid input
with pytest.raises(
TypeError, match="Callbacks must be a callable or a sequence of callables"
):
_normalise_callbacks(42) # type: ignore[arg-type]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All good tests. Just flagging that we're favouring the parametrised approach rather than nesting multiple cases (in the event the setup is cheap).

We also have a raises_error fixture because that pytest.raises syntax is very common in our tests 😅

So would suggest something like

def placeholder_callback(_: IterationResult) -> None:
    pass

@pytest.mark.parametrize(
    ("input", "expected"),
    [
        pytest.param([], [], id="Empty list",
        pytest.param(placeholder_callback, [placeholder_callback], id="Single callable",
        pytest.param(42, TypeError("Callbacks...."), id="..."),
        # Add other test cases as necessary
    ]
)
def test_normalise_callbacks(input, expected, raises_error) -> None:
    """Test that callbacks are normalised correctly."""

    if isinstance(expected, Exception):
        with raises_context(expected):
            _normalise_callbacks(input)
    else:
        assert _normalise_callbacks(input) == expected

(with typehints etc added as needed).

)


@pytest.mark.parametrize(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't advocate for doing this here and now, but I think we can likely condense these tests a bit.

If I'm right, we're currently checking:

  • The number of iterations logged is correct
  • The parameters / objective function / gradient value logged at each iteration is correct
  • Callbacks are invoked correctly regardless of shape (whether they are a list / single callable / None etc).
  • And catching the error case of the above (when not given callables).
  • Testing that callbacks don't affect the SGD result / convergence.

The correct invocation (and its associated error catch) are the same things that we're checking in _normalise_callbacks. As such, I'm of the opinion that we don't need to test for catching them here (since the tests for _normalise_callbacks will flag what happens if we pass bad things in here!) - and we should just pass valid entries to sgd's callbacks argument. Testing these callbacks return & log the expected values however, is of course something we should still be doing!

Value logging is probably worth checking, but we can probably drop one of the "interval=2" and "interval=3" cases (the purpose of both tests is to check the logging interval is respected), and one of the "interval=0" and "interval=-1" cases (which both check something sensible happens for a nonsensical input).

This means that it's probably possible to condense these 3 tests into a single test function (with parametrisation) along the lines of "test_sgd_logging". Where in each test we check logging, recording, and non-effect on convergence in each case. But that sounds like a lot of reorganisation, which I should probably just break out into a follow-on issue 😅

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm right, we're currently checking:

  • The number of iterations logged is correct
  • The parameters / objective function / gradient value logged at each iteration is correct
  • Callbacks are invoked correctly regardless of shape (whether they are a list / single callable / None etc).
  • And catching the error case of the above (when not given callables).
  • Testing that callbacks don't affect the SGD result / convergence.

Yes and for the last bullet point, we also testing that convergence isn't affected by combinations of history logging and callbacks. With the additional caveat that if the IterationResult attributes are directly changed then of course the convergence will differ.

The correct invocation (and its associated error catch) are the same things that we're checking in _normalise_callbacks. As such, I'm of the opinion that we don't need to test for catching them here (since the tests for _normalise_callbacks will flag what happens if we pass bad things in here!) - and we should just pass valid entries to sgd's callbacks argument.

Opting to test both was intentional. My thoughts are that, 1). I would like to know _normalise_callbacks works correctly and 2). it is implemented correctly in each solver. I think we could remove either one of test_sgd_callbacks_invoaction or test_normalise_callbacks. But I personally favour removing test_normalise_callbacks, and keeping test_sgd_callbacks_invocation because I think it's more important to know that it is implemented correctly in each solver.

Value logging is probably worth checking, but we can probably drop one of the "interval=2" and "interval=3" cases (the purpose of both tests is to check the logging interval is respected), and one of the "interval=0" and "interval=-1" cases (which both check something sensible happens for a nonsensical input).

Yeah I can remove those. My brain always just questions if there is something special about the first edge case that makes it work correctly, so I always feel the need to excessively add more!

This means that it's probably possible to condense these 3 tests into a single test function (with parametrisation) along the lines of "test_sgd_logging". Where in each test we check logging, recording, and non-effect on convergence in each case. But that sounds like a lot of reorganisation, which I should probably just break out into a follow-on issue 😅

Got it 👍 .

Comment on lines +230 to +238
def obj_fn(x):
return (x**2).sum()

calls = []

def callback(iter_result: IterationResult) -> None:
calls.append(iter_result.iters)

callbacks = make_callbacks(callback)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm seeing this a lot in this file, maybe some local fixtures would help?

EG:

@pytest.fixture
def obj_fn() -> Callable[[PyTree], npt.ArrayLike]:
    def _inner(x):
        return (x**2).sum()
    return _inner

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a matter of preference.

My workflow with tests is, when a test fails, I want to be able to easily jump to file and line number and be able to immediately see what's happening. For that reason I shy away from using lots of fixtures and forcing people to jump around while debugging. With the obvious exceptions being long pieces of code and computationally demanding tasks.

I tend to favour maximising readability and proximity over minimising verbosity and repetition in cases where the repeated code doesn't need to remain identical in all places.

For this reason you'll notice I similarly recalculate baseline_result multiple times in test_logging_or_callbacks_affect_sgd_convergence instead of taking it out into a fixture. It adds less the half a second to the tests, but reduces the reading comprehension time by much more than that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants