diff --git a/python/tests/test_rdata.py b/python/tests/test_rdata.py index 29ea401932..ac7659f363 100644 --- a/python/tests/test_rdata.py +++ b/python/tests/test_rdata.py @@ -2,7 +2,8 @@ import amici import numpy as np import pytest -from numpy.testing import assert_array_equal +from amici.numpy import evaluate +from numpy.testing import assert_almost_equal, assert_array_equal @pytest.fixture(scope="session") @@ -39,3 +40,27 @@ def test_rdata_by_id(rdata_by_id_fixture): assert_array_equal( rdata.by_id(model.getStateIds()[1], "sx", model), rdata.sx[:, :, 1] ) + + +def test_evaluate(rdata_by_id_fixture): + # get IDs of model components + model, rdata = rdata_by_id_fixture + expr0_id = model.getExpressionIds()[0] + state1_id = model.getStateIds()[1] + observable0_id = model.getObservableIds()[0] + + # ensure `evaluate` works for atoms + expr0 = rdata.by_id(expr0_id) + assert_array_equal(expr0, evaluate(expr0_id, rdata=rdata)) + + state1 = rdata.by_id(state1_id) + assert_array_equal(state1, evaluate(state1_id, rdata=rdata)) + + observable0 = rdata.by_id(observable0_id) + assert_array_equal(observable0, evaluate(observable0_id, rdata=rdata)) + + # ensure `evaluate` works for expressions + assert_almost_equal( + expr0 + state1 * observable0, + evaluate(f"{expr0_id} + {state1_id} * {observable0_id}", rdata=rdata), + )