diff --git a/docs/source/api.rst b/docs/source/api.rst index a82da9bc99..d80c0984ff 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -22,6 +22,7 @@ API api/shape_utils api/backends api/misc + api/testing ------------------ Dimensionality diff --git a/docs/source/api/testing.rst b/docs/source/api/testing.rst new file mode 100644 index 0000000000..10e01a5cf2 --- /dev/null +++ b/docs/source/api/testing.rst @@ -0,0 +1,14 @@ +======= +Testing +======= + +This submodule contains tools to help with testing PyMC code. + + +.. currentmodule:: pymc.testing + +.. autosummary:: + :toctree: generated/ + + mock_sample + mock_sample_setup_and_teardown diff --git a/pymc/testing.py b/pymc/testing.py index 7ef6751892..a5fdc28327 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -22,6 +22,7 @@ import pytensor import pytensor.tensor as pt +from arviz import InferenceData from numpy import random as nr from numpy import testing as npt from pytensor.compile.mode import Mode @@ -982,3 +983,115 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None: rvs = rvs_in_graph(vars) if rvs: raise AssertionError(f"RV found in graph: {rvs}") + + +def mock_sample(draws: int = 10, **kwargs): + """Mock :func:`pymc.sample` with :func:`pymc.sample_prior_predictive`. + + Useful for testing models that use pm.sample without running MCMC sampling. + + Examples + -------- + Using mock_sample with pytest + + .. note:: + + Use :func:`pymc.testing.mock_sample_setup_and_teardown` directly for pytest fixtures. + + .. code-block:: python + + import pytest + + import pymc as pm + from pymc.testing import mock_sample + + + @pytest.fixture(scope="module") + def mock_pymc_sample(): + original_sample = pm.sample + pm.sample = mock_sample + + yield + + pm.sample = original_sample + + """ + random_seed = kwargs.get("random_seed", None) + model = kwargs.get("model", None) + draws = kwargs.get("draws", draws) + n_chains = kwargs.get("chains", 1) + idata: InferenceData = pm.sample_prior_predictive( + model=model, + random_seed=random_seed, + draws=draws, + ) + + idata.add_groups( + posterior=( + idata["prior"] + .isel(chain=0) + .expand_dims({"chain": range(n_chains)}) + .transpose("chain", "draw", ...) + ) + ) + del idata["prior"] + if "prior_predictive" in idata: + del idata["prior_predictive"] + return idata + + +def mock_sample_setup_and_teardown(): + """Set up and tear down mocking of PyMC sampling functions for testing. + + This function is designed to be used with pytest fixtures to temporarily replace + PyMC's sampling functionality with faster alternatives for testing purposes. + + Effects during the fixture's active period: + + * Replaces :func:`pymc.sample` with :func:`pymc.testing.mock_sample`, which uses + prior predictive sampling instead of MCMC + * Replaces distributions: + * :class:`pymc.Flat` with :class:`pymc.Normal` + * :class:`pymc.HalfFlat` with :class:`pymc.HalfNormal` + * Automatically restores all original functions and distributions after the test completes + + Examples + -------- + Use with `pytest` to mock actual PyMC sampling in test suite. + + .. code-block:: python + + # tests/conftest.py + import pytest + import pymc as pm + from pymc.testing import mock_sample_setup_and_teardown + + # Register as a pytest fixture + mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown) + + + # tests/test_model.py + # Use in a test function + def test_model_inference(mock_pymc_sample): + with pm.Model() as model: + x = pm.Normal("x", 0, 1) + # This will use mock_sample instead of actual MCMC + idata = pm.sample() + # Test with the inference data... + + """ + import pymc as pm + + original_flat = pm.Flat + original_half_flat = pm.HalfFlat + original_sample = pm.sample + + pm.sample = mock_sample + pm.Flat = pm.Normal + pm.HalfFlat = pm.HalfNormal + + yield + + pm.sample = original_sample + pm.Flat = original_flat + pm.HalfFlat = original_half_flat diff --git a/tests/test_testing.py b/tests/test_testing.py index c8caf063c2..105e2f6209 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -15,7 +15,10 @@ import pytest -from pymc.testing import Domain +import pymc as pm + +from pymc.testing import Domain, mock_sample, mock_sample_setup_and_teardown +from tests.models import simple_normal @pytest.mark.parametrize( @@ -32,3 +35,49 @@ def test_domain(values, edges, expectation): with expectation: Domain(values, edges=edges) + + +@pytest.mark.parametrize( + "args, kwargs, expected_size", + [ + pytest.param((), {}, (1, 10), id="default"), + pytest.param((100,), {}, (1, 100), id="positional-draws"), + pytest.param((), {"draws": 100}, (1, 100), id="keyword-draws"), + pytest.param((100,), {"chains": 6}, (6, 100), id="chains"), + ], +) +def test_mock_sample(args, kwargs, expected_size) -> None: + expected_chains, expected_draws = expected_size + _, model, _ = simple_normal(bounded_prior=True) + + with model: + idata = mock_sample(*args, **kwargs) + + assert "posterior" in idata + assert "observed_data" in idata + assert "prior" not in idata + assert "posterior_predictive" not in idata + assert "sample_stats" not in idata + + assert idata.posterior.sizes == {"chain": expected_chains, "draw": expected_draws} + + +mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown) + + +@pytest.fixture(scope="function") +def dummy_model() -> pm.Model: + with pm.Model() as model: + pm.Flat("flat") + pm.HalfFlat("half_flat") + + return model + + +def test_fixture(mock_pymc_sample, dummy_model) -> None: + with dummy_model: + idata = pm.sample() + + posterior = idata.posterior + assert posterior.sizes == {"chain": 1, "draw": 10} + assert (posterior["half_flat"] >= 0).all()