Skip to content

Make time-based heaviside and modulo inputs explicit t_eval points with IDAKLUSolver #4994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

## Bug fixes

- Fixed a bug using a time-varying input with heaviside or modulo functions using the `IDAKLUSolver`. ([#4994](https://github.com/pybamm-team/PyBaMM/pull/4994))
- Fixed a bug in the `QuickPlot` which would return empty values for 1D variables at the beginning and end of a timespan. ([#4991](https://github.com/pybamm-team/PyBaMM/pull/4991))
- Fixed a bug in the `Exponential1DSubMesh` where the mesh was not being created correctly for non-zero minimum values. ([#4989](https://github.com/pybamm-team/PyBaMM/pull/4989))

## Breaking changes

- Remove sensitivity functionality for Casadi and Scipy solvers, only `pybamm.IDAKLU` solver can calculate sensitivities. ([#4975](https://github.com/pybamm-team/PyBaMM/pull/4975))
- Remove sensitivity functionality for Casadi and Scipy solvers, only `pybamm.IDAKLUSolver` can calculate sensitivities. ([#4975](https://github.com/pybamm-team/PyBaMM/pull/4975))

# [v25.4.2](https://github.com/pybamm-team/PyBaMM/tree/v25.4.2) - 2025-04-17

Expand Down
8 changes: 7 additions & 1 deletion src/pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datetime import timedelta
import pybamm.telemetry
from pybamm.util import import_optional_dependency

from copy import copy
from pybamm.expression_tree.operations.serialise import Serialise


Expand Down Expand Up @@ -445,6 +445,9 @@ def solve(
"""
pybamm.telemetry.capture("simulation-solved")

# Copy t_eval to avoid modifying the original
t_eval = copy(t_eval)

# Setup
if solver is None:
solver = self._solver
Expand Down Expand Up @@ -1026,6 +1029,9 @@ def step(
Additional key-word arguments passed to `solver.solve`.
See :meth:`pybamm.BaseSolver.step`.
"""
# Copy t_eval to avoid modifying the original
t_eval = copy(t_eval)

if self.operating_mode in ["without experiment", "drive cycle"]:
self.build()

Expand Down
131 changes: 96 additions & 35 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,46 +451,107 @@ def _set_up_events(self, model, t_eval, inputs, vars_for_processing):
# discontinuity events if these exist.
# Note: only checks for the case of t < X, t <= X, X < t, or X <= t,
# but also accounts for the fact that t might be dimensional
# Only do this for DAE models as ODE models can deal with discontinuities
# fine

if len(model.algebraic) > 0:
for symbol in itertools.chain(
model.concatenated_rhs.pre_order(),
model.concatenated_algebraic.pre_order(),
):
if isinstance(symbol, _Heaviside):
if symbol.right == pybamm.t:
expr = symbol.left
elif symbol.left == pybamm.t:
expr = symbol.right
else:
# Heaviside function does not contain pybamm.t as an argument.
# Do not create an event
continue # pragma: no cover

model.events.append(
pybamm.Event(
str(symbol),
expr,
pybamm.EventType.DISCONTINUITY,
)
t0 = np.min(t_eval)
tf = np.max(t_eval)

def supports_t_eval_discontinuities(expr):
# Only IDAKLUSolver supports discontinuities represented by t_eval
return (
(t_eval is not None)
and isinstance(self, pybamm.IDAKLUSolver)
and expr.is_constant()
)

def append_t_eval(t):
if t0 <= t <= tf and t not in t_eval:
# Insert t in the correct position to maintain sorted order
idx = np.searchsorted(t_eval, t)
t_eval.insert(idx, t)

def heaviside_event(symbol, expr):
model.events.append(
pybamm.Event(
str(symbol),
expr,
pybamm.EventType.DISCONTINUITY,
)
)

def heaviside_t_eval(symbol, expr):
value = expr.evaluate()
append_t_eval(value)

if isinstance(symbol, pybamm.EqualHeaviside):
if symbol.left == pybamm.t:
# t <= x
# Stop at t = x and right after t = x
append_t_eval(np.nextafter(value, np.inf))
else:
# t >= x
# Stop at t = x and right before t = x
append_t_eval(np.nextafter(value, -np.inf))
elif isinstance(symbol, pybamm.NotEqualHeaviside):
if symbol.left == pybamm.t:
# t < x
# Stop at t = x and right before t = x
append_t_eval(np.nextafter(value, -np.inf))
else:
# t > x
# Stop at t = x and right after t = x
append_t_eval(np.nextafter(value, np.inf))
else:
raise ValueError(
f"Unknown heaviside function: {symbol}"
) # pragma: no cover

def modulo_event(symbol, expr, num_events):
for i in np.arange(num_events):
model.events.append(
pybamm.Event(
str(symbol),
expr * pybamm.Scalar(i + 1),
pybamm.EventType.DISCONTINUITY,
)
)

elif isinstance(symbol, pybamm.Modulo) and symbol.left == pybamm.t:
def modulo_t_eval(symbol, expr, num_events):
value = expr.evaluate()
for i in np.arange(num_events):
t = value * (i + 1)
# Stop right before t and at t
append_t_eval(np.nextafter(t, -np.inf))
append_t_eval(t)

for symbol in itertools.chain(
model.concatenated_rhs.pre_order(),
model.concatenated_algebraic.pre_order(),
):
if isinstance(symbol, _Heaviside):
if symbol.right == pybamm.t:
expr = symbol.left
elif symbol.left == pybamm.t:
expr = symbol.right
num_events = 200 if (t_eval is None) else (t_eval[-1] // expr.value)

for i in np.arange(num_events):
model.events.append(
pybamm.Event(
str(symbol),
expr * pybamm.Scalar(i + 1),
pybamm.EventType.DISCONTINUITY,
)
)
else:
continue
# Heaviside function does not contain pybamm.t as an argument.
# Do not create an event
continue # pragma: no cover

if supports_t_eval_discontinuities(expr):
heaviside_t_eval(symbol, expr)
else:
heaviside_event(symbol, expr)

elif isinstance(symbol, pybamm.Modulo) and symbol.left == pybamm.t:
expr = symbol.right
num_events = 200 if (t_eval is None) else (tf // expr.value)

if supports_t_eval_discontinuities(expr):
modulo_t_eval(symbol, expr, num_events)
else:
modulo_event(symbol, expr, num_events)
else:
continue

casadi_switch_events = []
terminate_events = []
Expand Down
96 changes: 86 additions & 10 deletions tests/unit/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,27 +634,103 @@ def test_drive_cycle_interpolant(self):
i += 1
assert i < len(sim.solution.t)

def test_discontinuous_current(self):
# Test with an ODE and DAE model
@pytest.mark.parametrize(
"model", [pybamm.lithium_ion.SPM(), pybamm.lithium_ion.DFN()]
)
def test_heaviside_current(self, model):
def car_current(t):
current = (
1 * (t >= 0) * (t <= 1000)
- 0.5 * (1000 < t) * (t <= 2000)
1 * (t <= 1000)
- 0.5 * (1000 < t) * (t < 1500)
+ 0.5 * (2000 < t)
+ 5 * (t >= 3601)
)
return current

model = pybamm.lithium_ion.DFN()
def prevfloat(t):
return np.nextafter(np.float64(t), -np.inf)

def nextfloat(t):
return np.nextafter(np.float64(t), np.inf)

t_eval = [0.0, 3600.0]

t_nodes = np.array(
[
0.0, # t_eval[0]
1000.0, # t <= 1000
nextfloat(1000.0), # t <= 1000
prevfloat(1500.0), # t < 1500
1500.0, # t < 1500
2000.0, # 2000 < t
nextfloat(2000.0), # 2000 < t
3600.0, # t_eval[-1]
]
)

param = model.default_parameter_values
param["Current function [A]"] = car_current

sim = pybamm.Simulation(
model, parameter_values=param, solver=pybamm.CasadiSolver(mode="fast")
sim = pybamm.Simulation(model, parameter_values=param)

# Set t_interp to t_eval to only return the breakpoints
sol = sim.solve(t_eval, t_interp=t_eval)

np.testing.assert_array_equal(sol.t, t_nodes)
# Make sure t_eval is not modified
assert t_eval == [0.0, 3600.0]

current = sim.solution["Current [A]"]

for t_node in t_nodes:
assert current(t_node) == pytest.approx(car_current(t_node))

# Test with an ODE and DAE model
@pytest.mark.parametrize(
"model", [pybamm.lithium_ion.SPM(), pybamm.lithium_ion.DFN()]
)
def test_modulo_current(self, model):
dt = 1.0

def sawtooth_current(t):
return t % dt

def prevfloat(t):
return np.nextafter(np.float64(t), -np.inf)

t_eval = [0.0, 10.5]

t_nodes = np.arange(0.0, 10.5 + dt, dt)
t_nodes = np.concatenate(
[
t_nodes,
prevfloat(t_nodes),
t_eval,
]
)
sim.solve([0, 3600])

# Filter out all points not within t_eval
t_nodes = t_nodes[(t_nodes >= t_eval[0]) & (t_nodes <= t_eval[1])]

t_nodes = np.sort(np.unique(t_nodes))

param = model.default_parameter_values
param["Current function [A]"] = sawtooth_current

sim = pybamm.Simulation(model, parameter_values=param)

# Set t_interp to t_eval to only return the breakpoints
sol = sim.solve(t_eval, t_interp=t_eval)

np.testing.assert_array_equal(sol.t, t_nodes)
# Make sure t_eval is not modified
assert t_eval == [0.0, 10.5]

current = sim.solution["Current [A]"]
assert current(0) == 1
assert current(1500) == -0.5
assert current(3000) == 0.5

for t_node in t_nodes:
assert current(t_node) == pytest.approx(sawtooth_current(t_node))

def test_t_eval(self):
model = pybamm.lithium_ion.SPM()
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/test_solvers/test_casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,25 @@ def test_solver_interpolation_warning(self):
match=f"Explicit interpolation times not implemented for {solver.name}",
):
solver.solve(model, t_eval, t_interp=t_interp)

def test_discontinuous_current(self):
def car_current(t):
current = (
1 * (t >= 0) * (t <= 1000)
- 0.5 * (1000 < t) * (t <= 2000)
+ 0.5 * (2000 < t)
)
return current

model = pybamm.lithium_ion.SPM()
param = model.default_parameter_values
param["Current function [A]"] = car_current

sim = pybamm.Simulation(
model, parameter_values=param, solver=pybamm.CasadiSolver(mode="fast")
)
sim.solve([0, 3600])
current = sim.solution["Current [A]"]
assert current(0) == 1
assert current(1500) == -0.5
assert current(3000) == 0.5
Loading