Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Sep 19, 2023
1 parent f16293d commit 80b7592
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion python/tests/test_rdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import amici
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from amici.numpy import evaluate
from numpy.testing import assert_almost_equal, assert_array_equal


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -39,3 +40,27 @@ def test_rdata_by_id(rdata_by_id_fixture):
assert_array_equal(
rdata.by_id(model.getStateIds()[1], "sx", model), rdata.sx[:, :, 1]
)


def test_evaluate(rdata_by_id_fixture):
# get IDs of model components
model, rdata = rdata_by_id_fixture
expr0_id = model.getExpressionIds()[0]
state1_id = model.getStateIds()[1]
observable0_id = model.getObservableIds()[0]

# ensure `evaluate` works for atoms
expr0 = rdata.by_id(expr0_id)
assert_array_equal(expr0, evaluate(expr0_id, rdata=rdata))

state1 = rdata.by_id(state1_id)
assert_array_equal(state1, evaluate(state1_id, rdata=rdata))

observable0 = rdata.by_id(observable0_id)
assert_array_equal(observable0, evaluate(observable0_id, rdata=rdata))

# ensure `evaluate` works for expressions
assert_almost_equal(
expr0 + state1 * observable0,
evaluate(f"{expr0_id} + {state1_id} * {observable0_id}", rdata=rdata),
)

0 comments on commit 80b7592

Please sign in to comment.