Skip to content

Add public testing function to mock sample #7761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ API
api/shape_utils
api/backends
api/misc
api/testing

------------------
Dimensionality
Expand Down
14 changes: 14 additions & 0 deletions docs/source/api/testing.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
=======
Testing
=======

This submodule contains tools to help with testing PyMC code.


.. currentmodule:: pymc.testing

.. autosummary::
:toctree: generated/

mock_sample
mock_sample_setup_and_teardown
113 changes: 113 additions & 0 deletions pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytensor
import pytensor.tensor as pt

from arviz import InferenceData
from numpy import random as nr
from numpy import testing as npt
from pytensor.compile.mode import Mode
Expand Down Expand Up @@ -982,3 +983,115 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
rvs = rvs_in_graph(vars)
if rvs:
raise AssertionError(f"RV found in graph: {rvs}")


def mock_sample(draws: int = 10, **kwargs):
"""Mock :func:`pymc.sample` with :func:`pymc.sample_prior_predictive`.

Useful for testing models that use pm.sample without running MCMC sampling.

Examples
--------
Using mock_sample with pytest

.. note::

Use :func:`pymc.testing.mock_sample_setup_and_teardown` directly for pytest fixtures.

.. code-block:: python

import pytest

import pymc as pm
from pymc.testing import mock_sample


@pytest.fixture(scope="module")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should just provide this fixture, since this is how it will be used in most cases

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've run into some issue with the scope. I forget the exact rules but it seems like there is some clash when scope is different. That is my only worry so we don't need to define all the scopes. Will explore

def mock_pymc_sample():
original_sample = pm.sample
pm.sample = mock_sample

yield

pm.sample = original_sample

"""
random_seed = kwargs.get("random_seed", None)
model = kwargs.get("model", None)
draws = kwargs.get("draws", draws)
n_chains = kwargs.get("chains", 1)
idata: InferenceData = pm.sample_prior_predictive(
model=model,
random_seed=random_seed,
draws=draws,
)

idata.add_groups(
posterior=(
idata["prior"]
.isel(chain=0)
.expand_dims({"chain": range(n_chains)})
.transpose("chain", "draw", ...)
)
)
del idata["prior"]
if "prior_predictive" in idata:
del idata["prior_predictive"]
return idata


def mock_sample_setup_and_teardown():
"""Set up and tear down mocking of PyMC sampling functions for testing.

This function is designed to be used with pytest fixtures to temporarily replace
PyMC's sampling functionality with faster alternatives for testing purposes.

Effects during the fixture's active period:

* Replaces :func:`pymc.sample` with :func:`pymc.testing.mock_sample`, which uses
prior predictive sampling instead of MCMC
* Replaces distributions:
* :class:`pymc.Flat` with :class:`pymc.Normal`
* :class:`pymc.HalfFlat` with :class:`pymc.HalfNormal`
* Automatically restores all original functions and distributions after the test completes

Examples
--------
Use with `pytest` to mock actual PyMC sampling in test suite.

.. code-block:: python

# tests/conftest.py
import pytest
import pymc as pm
from pymc.testing import mock_sample_setup_and_teardown

# Register as a pytest fixture
mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)


# tests/test_model.py
# Use in a test function
def test_model_inference(mock_pymc_sample):
with pm.Model() as model:
x = pm.Normal("x", 0, 1)
# This will use mock_sample instead of actual MCMC
idata = pm.sample()
# Test with the inference data...

"""
import pymc as pm

original_flat = pm.Flat
original_half_flat = pm.HalfFlat
original_sample = pm.sample

pm.sample = mock_sample
pm.Flat = pm.Normal
pm.HalfFlat = pm.HalfNormal

yield

pm.sample = original_sample
pm.Flat = original_flat
pm.HalfFlat = original_half_flat
51 changes: 50 additions & 1 deletion tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

import pytest

from pymc.testing import Domain
import pymc as pm

from pymc.testing import Domain, mock_sample, mock_sample_setup_and_teardown
from tests.models import simple_normal


@pytest.mark.parametrize(
Expand All @@ -32,3 +35,49 @@
def test_domain(values, edges, expectation):
with expectation:
Domain(values, edges=edges)


@pytest.mark.parametrize(
"args, kwargs, expected_size",
[
pytest.param((), {}, (1, 10), id="default"),
pytest.param((100,), {}, (1, 100), id="positional-draws"),
pytest.param((), {"draws": 100}, (1, 100), id="keyword-draws"),
pytest.param((100,), {"chains": 6}, (6, 100), id="chains"),
],
)
def test_mock_sample(args, kwargs, expected_size) -> None:
expected_chains, expected_draws = expected_size
_, model, _ = simple_normal(bounded_prior=True)

with model:
idata = mock_sample(*args, **kwargs)

assert "posterior" in idata
assert "observed_data" in idata
assert "prior" not in idata
assert "posterior_predictive" not in idata
assert "sample_stats" not in idata

assert idata.posterior.sizes == {"chain": expected_chains, "draw": expected_draws}


mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)


@pytest.fixture(scope="function")
def dummy_model() -> pm.Model:
with pm.Model() as model:
pm.Flat("flat")
pm.HalfFlat("half_flat")

return model


def test_fixture(mock_pymc_sample, dummy_model) -> None:
with dummy_model:
idata = pm.sample()

posterior = idata.posterior
assert posterior.sizes == {"chain": 1, "draw": 10}
assert (posterior["half_flat"] >= 0).all()