
Commit 3fbed8d

Add tests for sgd callbacks and history logging
1 parent eaba080 commit 3fbed8d

File tree

2 files changed: +287 -0

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
import pytest

from causalprog.solvers.iteration_result import IterationResult
from causalprog.solvers.solver_callbacks import _normalise_callbacks


def test_normalise_callbacks() -> None:
    """Test that callbacks are normalised correctly."""

    def callback(iter_result: IterationResult) -> None:
        pass

    # Test single callable
    assert _normalise_callbacks(callback) == [callback]

    # Test sequence of callables
    assert _normalise_callbacks([callback, callback]) == [callback, callback]

    # Test None
    assert _normalise_callbacks(None) == []

    # Test empty sequence
    assert _normalise_callbacks([]) == []

    # Test invalid input
    with pytest.raises(
        TypeError, match="Callbacks must be a callable or a sequence of callables"
    ):
        _normalise_callbacks(42)  # type: ignore[arg-type]
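
For reference, a minimal implementation consistent with the contract exercised
above might look like the sketch below. It is inferred from the test assertions,
not taken from the causalprog source; the name _normalise_callbacks_sketch is
illustrative.

from collections.abc import Callable, Sequence


def _normalise_callbacks_sketch(callbacks) -> list[Callable]:
    """Normalise a callbacks argument into a list of callables (sketch only)."""
    # None means "no callbacks".
    if callbacks is None:
        return []
    # A single callable is wrapped in a one-element list.
    if callable(callbacks):
        return [callbacks]
    # A sequence of callables is returned as a plain list (an empty
    # sequence normalises to an empty list).
    if isinstance(callbacks, Sequence) and all(callable(cb) for cb in callbacks):
        return list(callbacks)
    # Anything else is rejected, with the message the tests match against.
    msg = "Callbacks must be a callable or a sequence of callables"
    raise TypeError(msg)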

tests/test_solvers/test_sgd.py

Lines changed: 258 additions & 0 deletions
@@ -6,6 +6,7 @@
import numpy.typing as npt
import pytest

from causalprog.solvers.iteration_result import IterationResult
from causalprog.solvers.sgd import stochastic_gradient_descent
from causalprog.utils.norms import PyTree

@@ -87,3 +88,260 @@ def test_sgd(
    assert jax.tree_util.tree_all(
        jax.tree_util.tree_map(jax.numpy.allclose, result.fn_args, expected)
    )


@pytest.mark.parametrize(
    (
        "history_logging_interval",
        "expected_iters",
    ),
    [
        pytest.param(
            1,
            list(range(11)),
            id="interval=1",
        ),
        pytest.param(
            2,
            list(range(0, 11, 2)),
            id="interval=2",
        ),
        pytest.param(
            3,
            list(range(0, 11, 3)),
            id="interval=3",
        ),
        pytest.param(
            0,
            [],
            id="interval=0 (no logging)",
        ),
        pytest.param(
            -1,
            [],
            id="interval=-1 (no logging)",
        ),
    ],
)
def test_sgd_history_logging_intervals(
    history_logging_interval: int, expected_iters: list[int]
) -> None:
    """Test that history logging intervals work correctly."""

    def obj_fn(x):
        return (x**2).sum()

    initial_guess = jnp.atleast_1d(1.0)

    result = stochastic_gradient_descent(
        obj_fn,
        initial_guess,
        maxiter=10,
        tolerance=0.0,
        history_logging_interval=history_logging_interval,
    )

    # Check that the correct iterations were logged
    assert result.iter_history == expected_iters, (
        f"IterationResult.iter_history logged incorrectly. Expected {expected_iters}. "
        f"Got {result.iter_history}"
    )

    # Check that the correct number of fn_args, grad_val, obj_val were logged
    assert len(result.fn_args_history) == len(expected_iters), (
        "IterationResult.fn_args_history logged incorrectly. "
        f"Expected {len(expected_iters)} entries. Got {len(result.fn_args_history)}"
    )
    assert len(result.grad_val_history) == len(expected_iters), (
        "IterationResult.grad_val_history logged incorrectly. "
        f"Expected {len(expected_iters)} entries. Got {len(result.grad_val_history)}"
    )
    assert len(result.obj_val_history) == len(expected_iters), (
        "IterationResult.obj_val_history logged incorrectly. "
        f"Expected {len(expected_iters)} entries. Got {len(result.obj_val_history)}"
    )

    # Check that logged fn_args, grad_val, obj_val line up correctly
    value_and_grad_fn = jax.jit(jax.value_and_grad(obj_fn))

    if len(expected_iters) > 0:
        for fn_args, obj_val, grad_val in zip(
            result.fn_args_history,
            result.obj_val_history,
            result.grad_val_history,
            strict=True,
        ):
            real_obj_val, real_grad_val = value_and_grad_fn(fn_args)

            # Check that logged obj_val and fn_args line up correctly
            assert real_obj_val == obj_val, (
                "Logged obj_val does not match obj_fn evaluated at logged fn_args. "
                f"For fn_args {fn_args}, we expected {real_obj_val}, got {obj_val}."
            )

            # Check that logged gradient and fn_args line up correctly
            assert real_grad_val == grad_val, (
                "Logged grad_val does not match gradient of obj_fn evaluated at"
                f" logged fn_args. For fn_args {fn_args}, we expected"
                f" {real_grad_val}, got {grad_val}."
            )


@pytest.mark.parametrize(
    (
        "make_callbacks",
        "expected",
    ),
    [
        (
            lambda cb: cb,
            [0, 1, 2],
        ),
        (
            lambda cb: [cb],
            [0, 1, 2],
        ),
        (
            lambda cb: [cb, cb],
            [0, 0, 1, 1, 2, 2],
        ),
        (
            lambda cb: [],  # noqa: ARG005
            [],
        ),
        (
            lambda cb: None,  # noqa: ARG005
            [],
        ),
    ],
    ids=[
        "single callable",
        "list of one callable",
        "list of two callables",
        "callbacks=[]",
        "callbacks=None",
    ],
)
def test_sgd_callbacks_invocation(
    make_callbacks: Callable, expected: list[int]
) -> None:
    """Test SGD invokes callbacks correctly for all shapes of callbacks input."""

    def obj_fn(x):
        return (x**2).sum()

    calls = []

    def callback(iter_result: IterationResult) -> None:
        calls.append(iter_result.iters)

    callbacks = make_callbacks(callback)

    initial = jnp.atleast_1d(1.0)

    stochastic_gradient_descent(
        obj_fn,
        initial,
        maxiter=2,
        tolerance=0.0,
        callbacks=callbacks,
    )

    assert calls == expected, (
        f"Callback was not called correctly, got {calls}, expected {expected}"
    )


def test_sgd_invalid_callback() -> None:
    """Test that an invalid callbacks argument raises a TypeError."""

    def obj_fn(x):
        return (x**2).sum()

    initial = jnp.atleast_1d(1.0)

    with pytest.raises(
        TypeError, match="Callbacks must be a callable or a sequence of callables"
    ):
        stochastic_gradient_descent(
            obj_fn,
            initial,
            maxiter=2,
            tolerance=0.0,
            callbacks=42,  # type: ignore[arg-type]
        )


@pytest.mark.parametrize(
    "history_logging_interval", [0, 1, 2], ids=lambda v: f"hist:{v}"
)
@pytest.mark.parametrize(
    "make_callbacks",
    [
        lambda cb: cb,
        lambda cb: [cb],
        lambda cb: [cb, cb],
        lambda cb: [],  # noqa: ARG005
        lambda cb: None,  # noqa: ARG005
    ],
    ids=["callable", "list_1", "list_2", "empty", "none"],
)
def test_logging_or_callbacks_affect_sgd_convergence(
    history_logging_interval,
    make_callbacks,
) -> None:
    """Test that logging and callbacks don't affect convergence of the SGD solver."""
    calls = []

    def callback(iter_result: IterationResult) -> None:
        calls.append(iter_result.iters)

    callbacks = make_callbacks(callback)

    def obj_fn(x):
        return (x**2).sum()

    initial_guess = jnp.atleast_1d(1.0)

    baseline_result = stochastic_gradient_descent(
        obj_fn,
        initial_guess,
        maxiter=6,
        tolerance=0.0,
        history_logging_interval=0,
    )

    result = stochastic_gradient_descent(
        obj_fn,
        initial_guess,
        maxiter=6,
        tolerance=0.0,
        history_logging_interval=history_logging_interval,
        callbacks=callbacks,
    )

    baseline_attributes = [
        baseline_result.fn_args,
        baseline_result.obj_val,
        baseline_result.grad_val,
        baseline_result.iters,
        baseline_result.successful,
        baseline_result.reason,
    ]

    result_attributes = [
        result.fn_args,
        result.obj_val,
        result.grad_val,
        result.iters,
        result.successful,
        result.reason,
    ]

    for baseline_attr, result_attr in zip(
        baseline_attributes, result_attributes, strict=True
    ):
        assert baseline_attr == result_attr, (
            "Logging or callbacks changed the convergence behaviour of the"
            " solver. For history_logging_interval"
            f" {history_logging_interval}, callbacks {callbacks}, expected"
            f" {baseline_attributes}, got {result_attributes}"
        )
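
Taken together, these tests pin down the solver's two hook points: callbacks
fire on every iteration from 0 through maxiter, in the order supplied, and
history is recorded on every history_logging_interval-th iteration, with
non-positive intervals disabling logging. A loop structure consistent with the
assertions above might look like the following sketch. It is inferred from the
tests, not taken from causalprog.solvers.sgd; sgd_sketch, SketchResult, and
learning_rate are illustrative names, _normalise_callbacks_sketch is the helper
sketched earlier, and tolerance-based stopping and the successful/reason fields
are omitted.

from dataclasses import dataclass, field

import jax
import jax.numpy as jnp


@dataclass
class SketchResult:
    fn_args: jnp.ndarray
    obj_val: float
    grad_val: jnp.ndarray
    iters: int = 0
    iter_history: list = field(default_factory=list)
    fn_args_history: list = field(default_factory=list)
    obj_val_history: list = field(default_factory=list)
    grad_val_history: list = field(default_factory=list)


def sgd_sketch(
    obj_fn,
    initial_guess,
    maxiter,
    learning_rate=0.1,
    callbacks=None,
    history_logging_interval=0,
):
    cbs = _normalise_callbacks_sketch(callbacks)  # raises TypeError on bad input
    value_and_grad_fn = jax.jit(jax.value_and_grad(obj_fn))
    x = initial_guess
    result = SketchResult(x, obj_fn(x), x)
    for i in range(maxiter + 1):  # iteration 0 is the initial guess
        obj_val, grad_val = value_and_grad_fn(x)
        result.fn_args, result.obj_val = x, obj_val
        result.grad_val, result.iters = grad_val, i
        # Non-positive intervals disable logging entirely; otherwise every
        # history_logging_interval-th iteration is recorded, starting at 0.
        if history_logging_interval > 0 and i % history_logging_interval == 0:
            result.iter_history.append(i)
            result.fn_args_history.append(x)
            result.obj_val_history.append(obj_val)
            result.grad_val_history.append(grad_val)
        # Every callback sees every iteration, so two copies of the same
        # callback with maxiter=2 record calls == [0, 0, 1, 1, 2, 2].
        for cb in cbs:
            cb(result)
        x = x - learning_rate * grad_val
    return result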
