Skip to content

Commit

Permalink
Tweaked warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jul 1, 2024
1 parent 754dd79 commit d6d09dc
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 9 deletions.
1 change: 0 additions & 1 deletion diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,6 @@ def diffeqsolve(
"`diffrax.diffeqsolve(..., discrete_terminating_event=...)` is deprecated "
"in favour of the more general `diffrax.diffeqsolve(..., event=...)` "
"interface. This will be removed in some future version of Diffrax.",
category=DeprecationWarning,
stacklevel=2,
)
if event is None:
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_step_size_controller/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def _scale(_y0, _y1_candidate, _y_error):
# a grad API boundary as part of a larger model.)
factor = lax.stop_gradient(factor)
factor = eqxi.nondifferentiable(factor)
dt = prev_dt * factor.astype(prev_dt)
dt = prev_dt * factor.astype(jnp.result_type(prev_dt))

# E.g. we failed an implicit step, so y_error=inf, so inv_scaled_error=0,
# so factor=factormin, and we shrunk our step.
Expand Down
6 changes: 3 additions & 3 deletions test/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def event_fn(state, **kwargs):
return state.tprev > 10

event = diffrax.DiscreteTerminatingEvent(event_fn)
with pytest.warns(DeprecationWarning, match="discrete_terminating_event"):
with pytest.warns(match="discrete_terminating_event"):
sol = diffrax.diffeqsolve(
term,
solver,
Expand All @@ -51,7 +51,7 @@ def event_fn(state, **kwargs):
return state.tprev > 10

event = diffrax.DiscreteTerminatingEvent(event_fn)
with pytest.warns(DeprecationWarning, match="discrete_terminating_event"):
with pytest.warns(match="discrete_terminating_event"):
sol = diffrax.diffeqsolve(
term,
solver,
Expand Down Expand Up @@ -82,7 +82,7 @@ def event_fn(state, **kwargs):
@jax.jit
@jax.grad
def run(y0):
with pytest.warns(DeprecationWarning, match="discrete_terminating_event"):
with pytest.warns(match="discrete_terminating_event"):
sol = diffrax.diffeqsolve(
term,
solver,
Expand Down
5 changes: 1 addition & 4 deletions test/test_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,7 @@ def __call__(self, t, y, args):


def test_weaklydiagonal_deprecate():
with pytest.warns(
DeprecationWarning,
match="WeaklyDiagonalControlTerm is pending deprecation",
):
with pytest.warns(match="WeaklyDiagonalControlTerm"):
_ = diffrax.WeaklyDiagonalControlTerm(
lambda t, y, args: 0.0, lambda t0, t1: jnp.array(t1 - t0)
)

0 comments on commit d6d09dc

Please sign in to comment.