Skip to content

Commit

Permalink
Refactor DEModel.splines -> DEModel._splines (#2292)
Browse files Browse the repository at this point in the history
To be consistent with all other model components.
  • Loading branch information
dweindl authored Feb 20, 2024
1 parent 24685a3 commit f9bfefc
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ def __init__(
self._expressions: list[Expression] = []
self._conservation_laws: list[ConservationLaw] = []
self._events: list[Event] = []
self.splines = []
self._splines = []
self._symboldim_funs: dict[str, Callable[[], int]] = {
"sx": self.num_states_solver,
"v": self.num_states_solver,
Expand Down Expand Up @@ -968,7 +968,7 @@ def import_from_sbml_importer(
value=spline_expr,
)
)
self.splines = si.splines
self._splines = si.splines

# get symbolic expression from SBML importers
symbols = copy.copy(si.symbols)
Expand Down Expand Up @@ -1690,15 +1690,15 @@ def _generate_symbol(self, name: str) -> None:
# placeholders for the numeric spline values.
# Need to create symbols
self._syms[name] = sp.Matrix(
[[f"spl_{isp}" for isp in range(len(self.splines))]]
[[f"spl_{isp}" for isp in range(len(self._splines))]]
)
return
elif name == "sspl":
# placeholders for spline sensitivities. Need to create symbols
self._syms[name] = sp.Matrix(
[
[f"sspl_{isp}_{ip}" for ip in range(len(self._syms["p"]))]
for isp in range(len(self.splines))
for isp in range(len(self._splines))
]
)
return
Expand Down Expand Up @@ -2050,15 +2050,15 @@ def _compute_equation(self, name: str) -> None:
elif name == "spline_values":
# force symbols
self._eqs[name] = sp.Matrix(
[y for spline in self.splines for y in spline.values_at_nodes]
[y for spline in self._splines for y in spline.values_at_nodes]
)

elif name == "spline_slopes":
# force symbols
self._eqs[name] = sp.Matrix(
[
d
for spline in self.splines
for spline in self._splines
for d in (
sp.zeros(len(spline.derivatives_at_nodes), 1)
if spline.derivatives_by_fd
Expand Down Expand Up @@ -2892,7 +2892,7 @@ def __init__(
self.model: DEModel = de_model
self.model._code_printer.known_functions.update(
splines.spline_user_functions(
self.model.splines, self._get_index("p")
self.model._splines, self._get_index("p")
)
)

Expand Down Expand Up @@ -3553,14 +3553,14 @@ def _get_function_body(
return [line for line in lines if line]

def _get_create_splines_body(self):
if not self.model.splines:
if not self.model._splines:
return [" return {};"]

ind4 = " " * 4
ind8 = " " * 8

body = ["return {"]
for ispl, spline in enumerate(self.model.splines):
for ispl, spline in enumerate(self.model._splines):
if isinstance(spline.nodes, splines.UniformGrid):
nodes = (
f"{ind8}{{{spline.nodes.start}, {spline.nodes.stop}}}, "
Expand Down Expand Up @@ -3674,7 +3674,7 @@ def _write_model_header_cpp(self) -> None:
"NEVENT": self.model.num_events(),
"NEVENT_SOLVER": self.model.num_events_solver(),
"NOBJECTIVE": "1",
"NSPL": len(self.model.splines),
"NSPL": len(self.model._splines),
"NW": len(self.model.sym("w")),
"NDWDP": len(
self.model.sparsesym(
Expand Down

0 comments on commit f9bfefc

Please sign in to comment.