Add callbacks, history logging and performance improvements to sgd. #105

Conversation
To sgd.py:

- Added single-pass JIT grad and objective calculation for big speed boosts.
- Added a history logging option for easy debugging and better understanding of the optimisation process.
- Added callbacks functionality, allowing for user-specific callbacks, e.g. early stopping, live graphs etc.

Added iteration_result.py to monitor the current state of convergence. Use of a dataclass ensures backwards compatibility of callbacks. Added solver_callbacks.py for callbacks and associated funcs. Added history attributes to the solver result.
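To make the intended usage concrete, here is a rough sketch of a user callback plugged into the new API. The import path and the exact `stochastic_gradient_descent` signature are assumptions for illustration, not the package's confirmed interface.

```python
# Hypothetical module path; the PR only says iteration_result.py was added.
from causalprog.solvers.iteration_result import IterationResult


def print_progress(iter_result: IterationResult) -> None:
    """Print the iteration counter every 100 iterations."""
    if iter_result.iters % 100 == 0:
        print(f"iteration {iter_result.iters}")


# Usage sketch (argument names assumed):
# stochastic_gradient_descent(
#     objective,
#     initial_params,
#     callbacks=[print_progress],
#     history_logging_interval=10,
# )
```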
All in all I like these changes, and these are useful features that we could do with adding. Callbacks in particular should be very useful from a development/debugging perspective too.

Most of my comments relate to the design decisions for the code, considering what's in the rest of the codebase. Namely, I think we can do some code recycling in places, and we tend to write our tests in a particular format (though the test cases provided are good).

Also, there are only two commits on this branch (one for codebase changes, one for tests). In general, don't be afraid to use more granular commits in your PRs (just take a look at how long the other PRs are!) - we use squash merges anyway, so everything gets condensed into a single commit on main. And it's good to be able to roll things back.
```python
    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)


def _update_iteration_result(
```
Any reason why this isn't a method of the `IterationResult` class? I expected this to be something like `IterationResult.update` (and it's currently written in that form too -> swap `iter_result` to `self`).

Related, any reason why `history_logging_interval` is an argument that we pass in, rather than an attribute that's set at creation time? (I guess we could want dynamic logging, which we wouldn't get with a fixed attribute, but is that a common enough use case to design around?) It also means that we could just check `history_logging_interval > 0` once at creation time, and not do it every time in the method.
> Any reason why this isn't a method of the `IterationResult` class? I expected this to be something like `IterationResult.update` (and it's currently written in that form too -> swap `iter_result` to `self`).

Nope. It should be.

> Related, any reason why `history_logging_interval` is an argument that we pass in, rather than an attribute that's set at creation time? (I guess we could want dynamic logging, which we wouldn't get with a fixed attribute, but is that a common enough use case to design around?) It also means that we could just check `history_logging_interval > 0` once at creation time, and not do it every time in the method.

No. This is a good point.
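For reference, a minimal sketch of what the agreed change could look like. The history field names are taken from the diff hunks in this PR; the `update` signature and the `__post_init__` flag are assumptions, and the current-state attributes are omitted for brevity.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class IterationResult:
    history_logging_interval: int = 0
    iter_history: list[int] = field(default_factory=list)
    fn_args_history: list[Any] = field(default_factory=list)
    grad_val_history: list[Any] = field(default_factory=list)
    obj_val_history: list[Any] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Checked once, at creation time, rather than on every call to update().
        self._log_history = self.history_logging_interval > 0

    def update(self, fn_args: Any, grad_val: Any, iteration: int, obj_val: Any) -> None:
        """Record the current iteration (the former _update_iteration_result)."""
        if self._log_history and iteration % self.history_logging_interval == 0:
            self.iter_history.append(iteration)
            self.fn_args_history.append(fn_args)
            self.grad_val_history.append(grad_val)
            self.obj_val_history.append(obj_val)
```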
```python
for _ in range(maxiter + 1):
    objective_value = objective(current_params)
    gradient_value = gradient(current_params)
    _update_iteration_result(
        iter_result,
        current_params,
        gradient_value,
        _,
        objective_value,
        history_logging_interval,
    )
```
Per Matt's convention from another PR, since we're actually using `_` in the loop, we should probably use a name like `current_iter` or something for the loop variable. (Would add a suggestion, sorry, but GitHub doesn't let me suggest things for unchanged lines.)
Yeah, because of the `iters_used = _` line, I wasn't sure if this was a convention you were using, so I didn't want to change it without checking.
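For concreteness, the rename applied to the loop fragment quoted above would look like this (a sketch of the suggestion only, mirroring the names in the diff):

```python
for current_iter in range(maxiter + 1):
    objective_value = objective(current_params)
    gradient_value = gradient(current_params)
    _update_iteration_result(
        iter_result,
        current_params,
        gradient_value,
        current_iter,
        objective_value,
        history_logging_interval,
    )
```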
```python
    iter_history: list[int] = field(default_factory=list)
    fn_args_history: list[PyTree] = field(default_factory=list)
    grad_val_history: list[PyTree] = field(default_factory=list)
    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
```
Should we just subclass here as `class SolverResult(IterationResult)`? We can drop the `frozen=True` and trust that the user isn't going to overwrite our results.

If we don't want to subclass: rather than adding all these attributes directly to the class, just add another attribute like `iteration_log` and assign its value as the `IterationResult` instance created from running the solver?
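A sketch of that composition option, with every field other than the suggested `iteration_log` left out because the rest of `SolverResult`'s definition isn't shown in this PR:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SolverResult:
    # ... existing frozen result fields ...
    iteration_log: "IterationResult | None" = None
```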
Two things I'm thinking about here:

1. I like `SolverResult` being frozen. Packages like these are used largely by Jupyter notebook users. It's very easy to reassign an attribute, delete the cell you reassigned it in, and then not realise that you have changed something. I think leaving `SolverResult` unfrozen will be a source of hard-to-detect problems for the end user.
2. `IterationResult` is being passed by reference to user-written callbacks, meaning a callback can directly change `IterationResult` and thus convergence behaviour. I was unsure whether this would be a desirable trait for causalprog. I personally like it, because the change in convergence behaviour can be intentional and useful, such as adding a callback that clips the parameters into some feasible set (see the sketch below). On the other hand, a user may accidentally reassign the params and break everything.

Either way, I think 2 is something to raise an issue about and think about later, but it certainly highlights a possible difference in direction for the two classes that might make it easier to keep them separated for now.

To summarise, I think we definitely want `SolverResult` to be frozen. I don't know what we want to do with `IterationResult`, so it seems easiest to keep it separate from `SolverResult` for now and not commit to anything until we need to.

But that's just my opinion. I might be overlooking something that makes the above non-issues.

Writing this out has also made me realise the need to make the history attributes mutable by the optimisation loop only, and not by callbacks.
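A sketch of the "clip into a feasible set" callback mentioned in point 2, illustrating how a by-reference `IterationResult` lets a callback steer convergence. The attribute name `fn_args` is inferred from `fn_args_history` in the diff and is an assumption, as is the solver reading the value back.

```python
import jax
import jax.numpy as jnp


def clip_params_callback(iter_result) -> None:
    """Clip every parameter leaf into [-1, 1], mutating the shared IterationResult."""
    iter_result.fn_args = jax.tree_util.tree_map(
        lambda leaf: jnp.clip(leaf, -1.0, 1.0), iter_result.fn_args
    )
```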
```python
from collections.abc import Callable

from tqdm.auto import tqdm
```
Is this the usual `tqdm` package, or some other package I'm not aware of? Either way, if it's not part of the standard library, we should be depending on it in `pyproject.toml`.
This was just a mistake.

I use it in the example notebook. Originally I defined it in the notebook along with a pip install there, but then refactored it into causalprog because I didn't want it to add confusion. I then just forgot to add it to the pyproject.toml.

I think the question is: do we want to keep it in causalprog, and if so, do we want it to be the default behaviour for these loops? We would benefit from some sort of default optimisation monitoring system, just so it's easy to know if something is broken. It can be a bit concerning waiting a few minutes for an optimiser without any feedback from the system.
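For reference, a progress-bar callback along these lines can live entirely outside the core loop, which would keep `tqdm` an optional extra rather than a hard dependency of the solver. A sketch, with the `obj_val` attribute name on `IterationResult` assumed:

```python
from tqdm.auto import tqdm


def make_progress_callback(maxiter: int):
    """Return a callback that advances a progress bar and shows the current objective."""
    bar = tqdm(total=maxiter, desc="sgd")

    def _callback(iter_result) -> None:
        bar.update(1)
        bar.set_postfix(objective=float(iter_result.obj_val))

    return _callback
```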
```python
        return []
    if callable(callbacks):
        return [callbacks]
    if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks):
```
```diff
-    if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks):
+    if isinstance(callbacks, Iterable) and all(callable(cb) for cb in callbacks):
```
I think this will be more robust - we could be passed a `tuple` or any other ordered container and things would still be fine.

(In fact, I'd be somewhat tempted to go full Python "try and beg forgiveness" ethos and just do `if all(callable(cb) for cb in callbacks):`, because if we get to this point and `callbacks` isn't an iterable, we'll get an error anyway. But what you have is probably better, since it gives the more general explanation "callbacks could be a list or just one callable".)
I went with lists because, of course, all sorts of things like strings and bytes are iterables. But now that I think of it, it wouldn't make much sense to have a string of callables 😂.

Would there be other weird cases that we've not considered? I suppose in theory someone could pass a dictionary where all the keys are callables, and this would pass the checks.

Maybe we could restrict to tuples and lists or something? Or maybe it is best to just do iterables without type checks and beg forgiveness like you say. I don't know.
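A sketch of what the `Iterable`-based variant discussed above could look like end to end. The function name and error message mirror this PR's `_normalise_callbacks`; the rest is illustrative rather than the actual implementation.

```python
from collections.abc import Callable, Iterable


def _normalise_callbacks(callbacks) -> list[Callable]:
    """Normalise None, a single callable, or an iterable of callables to a list."""
    if callbacks is None:
        return []
    if callable(callbacks):
        return [callbacks]
    if isinstance(callbacks, Iterable) and all(callable(cb) for cb in callbacks):
        # Note: a dict whose keys are all callables would also pass this check.
        return list(callbacks)
    msg = "Callbacks must be a callable or a sequence of callables"
    raise TypeError(msg)
```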
```python
def test_normalise_callbacks() -> None:
    """Test that callbacks are normalised correctly."""

    def callback(iter_result: IterationResult) -> None:
        pass

    # Test single callable
    assert _normalise_callbacks(callback) == [callback]

    # Test sequence of callables
    assert _normalise_callbacks([callback, callback]) == [callback, callback]

    # Test None
    assert _normalise_callbacks(None) == []

    # Test empty sequence
    assert _normalise_callbacks([]) == []

    # Test invalid input
    with pytest.raises(
        TypeError, match="Callbacks must be a callable or a sequence of callables"
    ):
        _normalise_callbacks(42)  # type: ignore[arg-type]
```
All good tests. Just flagging that we're favouring the parametrised approach rather than nesting multiple cases (in the event the setup is cheap).

We also have a `raises_error` fixture, because that `pytest.raises` syntax is very common in our tests 😅

So would suggest something like:

```python
def placeholder_callback(_: IterationResult) -> None:
    pass


@pytest.mark.parametrize(
    ("input", "expected"),
    [
        pytest.param([], [], id="Empty list"),
        pytest.param(placeholder_callback, [placeholder_callback], id="Single callable"),
        pytest.param(42, TypeError("Callbacks...."), id="..."),
        # Add other test cases as necessary
    ],
)
def test_normalise_callbacks(input, expected, raises_error) -> None:
    """Test that callbacks are normalised correctly."""
    if isinstance(expected, Exception):
        with raises_error(expected):
            _normalise_callbacks(input)
    else:
        assert _normalise_callbacks(input) == expected
```

(with typehints etc added as needed).
```python
)


@pytest.mark.parametrize(
```
I wouldn't advocate for doing this here and now, but I think we can likely condense these tests a bit.

If I'm right, we're currently checking:

- The number of iterations logged is correct.
- The parameters / objective function / gradient value logged at each iteration is correct.
- Callbacks are invoked correctly regardless of shape (whether they are a list / single callable / None etc).
- And catching the error case of the above (when not given callables).
- Testing that callbacks don't affect the SGD result / convergence.

The correct invocation (and its associated error catch) are the same things that we're checking in `_normalise_callbacks`. As such, I'm of the opinion that we don't need to test for catching them here (since the tests for `_normalise_callbacks` will flag what happens if we pass bad things in here!) - we should just pass valid entries to `sgd`'s `callbacks` argument. Testing that these callbacks return & log the expected values, however, is of course something we should still be doing!

Value logging is probably worth checking, but we can probably drop one of the "interval=2" and "interval=3" cases (the purpose of both tests is to check the logging interval is respected), and one of the "interval=0" and "interval=-1" cases (which both check something sensible happens for a nonsensical input).

This means that it's probably possible to condense these 3 tests into a single test function (with parametrisation) along the lines of "test_sgd_logging", where in each case we check logging, recording, and non-effect on convergence. But that sounds like a lot of reorganisation, which I should probably just break out into a follow-on issue 😅
> If I'm right, we're currently checking:
>
> - The number of iterations logged is correct.
> - The parameters / objective function / gradient value logged at each iteration is correct.
> - Callbacks are invoked correctly regardless of shape (whether they are a list / single callable / None etc).
> - And catching the error case of the above (when not given callables).
> - Testing that callbacks don't affect the SGD result / convergence.

Yes, and for the last bullet point, we are also testing that convergence isn't affected by combinations of history logging and callbacks. With the additional caveat that if the `IterationResult` attributes are directly changed then of course the convergence will differ.

> The correct invocation (and its associated error catch) are the same things that we're checking in `_normalise_callbacks`. As such, I'm of the opinion that we don't need to test for catching them here (since the tests for `_normalise_callbacks` will flag what happens if we pass bad things in here!) - we should just pass valid entries to `sgd`'s `callbacks` argument.

Opting to test both was intentional. My thoughts are that 1) I would like to know `_normalise_callbacks` works correctly, and 2) that it is implemented correctly in each solver. I think we could remove either one of `test_sgd_callbacks_invocation` or `test_normalise_callbacks`. But I personally favour removing `test_normalise_callbacks` and keeping `test_sgd_callbacks_invocation`, because I think it's more important to know that it is implemented correctly in each solver.

> Value logging is probably worth checking, but we can probably drop one of the "interval=2" and "interval=3" cases (the purpose of both tests is to check the logging interval is respected), and one of the "interval=0" and "interval=-1" cases (which both check something sensible happens for a nonsensical input).

Yeah, I can remove those. My brain always just questions if there is something special about the first edge case that makes it work correctly, so I always feel the need to excessively add more!

> This means that it's probably possible to condense these 3 tests into a single test function (with parametrisation) along the lines of "test_sgd_logging", where in each case we check logging, recording, and non-effect on convergence. But that sounds like a lot of reorganisation, which I should probably just break out into a follow-on issue 😅

Got it 👍.
```python
def obj_fn(x):
    return (x**2).sum()


calls = []


def callback(iter_result: IterationResult) -> None:
    calls.append(iter_result.iters)


callbacks = make_callbacks(callback)
```
I'm seeing this a lot in this file, maybe some local fixtures would help? E.g.:

```python
@pytest.fixture
def obj_fn() -> Callable[[PyTree], npt.ArrayLike]:
    def _inner(x):
        return (x**2).sum()

    return _inner
```
I think this is a matter of preference.

My workflow with tests is: when a test fails, I want to be able to easily jump to the file and line number and immediately see what's happening. For that reason I shy away from using lots of fixtures and forcing people to jump around while debugging, with the obvious exceptions being long pieces of code and computationally demanding tasks. I tend to favour maximising readability and proximity over minimising verbosity and repetition, in cases where the repeated code doesn't need to remain identical in all places.

For this reason you'll notice I similarly recalculate `baseline_result` multiple times in `test_logging_or_callbacks_affect_sgd_convergence` instead of pulling it out into a fixture. It adds less than half a second to the tests, but reduces the reading comprehension time by much more than that.
- By JIT-compiling with `jax.jit`, and calculating the gradient and objective in a single pass with `jax.value_and_grad`, we get big speed boosts to sgd (see the sketch after this list). The new time taken is approximately ...
- A new `history_logging_interval` parameter for the `stochastic_gradient_descent` function allows the user to enable or disable logging of the optimisation history. The interval determines how frequently the history is logged. This makes it easier to debug optimisations and make decisions about hyperparameters.
- Added `IterationResult`, which is analogous to `SolverResult` but is not created with `frozen=True`, to allow for dataclass updates each iteration. Using a dataclass ensures backward compatibility for callbacks if a new attribute is logged.
- A new `callbacks` parameter for the `stochastic_gradient_descent` function allows the user to set a list of callback functions, as is standard in optimisation loops. In future, the callbacks can be used for early stopping or live plotting of the results. As an example, we include a useful `tqdm` callback that displays a progress bar for the iterations and the current objective value.
- Tests:
  - `test_normalise_callbacks`: tests that `_normalise_callbacks` does validation and casts valid types to `list[Callable[[IterationResult], None]]`.
  - `test_sgd_history_logging_intervals`: tests that the correct iterations are logged for different intervals, and that the correct associated obj, fn_args and grad are too, for sgd.
  - `test_callback_invocation`: tests that sgd callbacks are called in the correct order with the correct `IterationResult`.
  - `test_invalid_callback`: tests that sgd will raise an error if given an invalid callback.
  - `test_logging_or_callbacks_affect_sgd_convergence`: tests that various combinations of callbacks and logging intervals all result in the same convergence behaviour, and thus all have the same final obj, fn_args, grad etc.
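A minimal, self-contained illustration of the single-pass evaluation described in the first bullet, using a toy objective rather than the package's actual sgd implementation:

```python
import jax
import jax.numpy as jnp

# One JIT-compiled function returns (objective, gradient) in a single evaluation.
value_and_grad_fn = jax.jit(jax.value_and_grad(lambda x: (x**2).sum()))

params = jnp.arange(3.0)
for _ in range(100):
    obj_val, grad_val = value_and_grad_fn(params)
    params = params - 0.1 * grad_val  # plain gradient step
```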