diff --git a/pymc/variational/minibatch_rv.py b/pymc/variational/minibatch_rv.py index 864825910df..be71a358c96 100644 --- a/pymc/variational/minibatch_rv.py +++ b/pymc/variational/minibatch_rv.py @@ -20,7 +20,8 @@ from pytensor.graph import Apply, Op from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable -from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper +from pymc.logprob.abstract import MeasurableOp, _logprob +from pymc.logprob.basic import logp class MinibatchRandomVariable(MeasurableOp, Op): @@ -99,4 +100,4 @@ def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> Tensor def minibatch_rv_logprob(op, values, *inputs, **kwargs): [value] = values rv, *total_size = inputs - return _logprob_helper(rv, value, **kwargs) * get_scaling(total_size, value.shape) + return logp(rv, value, **kwargs) * get_scaling(total_size, value.shape) diff --git a/tests/variational/test_minibatch_rv.py b/tests/variational/test_minibatch_rv.py index 6ef0c8dd707..6f3e715af7e 100644 --- a/tests/variational/test_minibatch_rv.py +++ b/tests/variational/test_minibatch_rv.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np import pytensor +import pytensor.tensor as pt import pytest from scipy import stats as st @@ -186,3 +187,12 @@ def test_minibatch_parameter_and_value(self): with m: pm.set_data({"AD": rng.normal(size=1000)}) assert logp_fn(ip) != logp_fn(ip) + + def test_derived_rv(self): + """Test we can obtain a minibatch logp out of a derived RV.""" + dist = pt.clip(pm.Normal.dist(0, 1, size=(1,)), -1, 1) + mb_dist = create_minibatch_rv(dist, total_size=(2,)) + np.testing.assert_allclose( + pm.logp(mb_dist, -1).eval(), + pm.logp(dist, -1).eval() * 2, + )