Skip to content

Commit

Permalink
Add diffing against IndexedBase instances
Browse files Browse the repository at this point in the history
  • Loading branch information
JCGoran committed Oct 14, 2024
1 parent 86e470d commit 02051ba
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
7 changes: 6 additions & 1 deletion python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,11 @@ def differentiate2c(
If the result coincides with one of the vars, or the LHS of one of
the prev_expressions, then it is simplified to this expression.
Note that, in order to differentiate against indexed variables (such as
``x[0]``), you must pass an instance of ``sympy.Indexed`` to
``dependent_var`` (_not_ an instance of ``sympy.IndexedBase``), as well as
an instance of ``sympy.IndexedBase`` to ``vars``.
Some simple examples of use:
- ``nmodl.ode.differentiate2c ("a*x", "x", {"a"}) == "a"``
Expand All @@ -619,7 +624,7 @@ def differentiate2c(
# every symbol (a.k.a variable) that SymPy
# is going to manipulate needs to be declared
# explicitly
x = sp.symbols(dependent_var, real=True)
x = make_symbol(dependent_var)
vars = set(vars)
vars.discard(dependent_var)
# declare all other supplied variables
Expand Down
13 changes: 13 additions & 0 deletions test/unit/ode/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,19 @@ def test_differentiate2c():
{sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])},
)

# make sure we can diff against indexed vars as well
var = sp.IndexedBase("x", shape=[1])

assert _equivalent(
differentiate2c(
"a * x[0]",
var[0],
{"a", var},
),
"a",
{"a"},
)

result = differentiate2c(
"-f(x)",
"x",
Expand Down

0 comments on commit 02051ba

Please sign in to comment.