
Commit 4caaaca

Do not guess order of inputs for users
1 parent f0d82f2 commit 4caaaca

8 files changed: +77 -31 lines changed


pymc/model/core.py
Lines changed: 34 additions & 12 deletions

@@ -35,7 +35,7 @@
 from pytensor.compile import DeepCopyOp, Function, ProfileStats, get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant, Variable
-from pytensor.graph.traversal import ancestors, graph_inputs
+from pytensor.graph.traversal import ancestors, explicit_graph_inputs, graph_inputs
 from pytensor.tensor import as_tensor
 from pytensor.tensor.math import variadic_add
 from pytensor.tensor.random.op import RandomVariable
@@ -62,7 +62,6 @@
     convert_observed_data,
     gradient,
     hessian,
-    inputvars,
     join_nonshared_inputs,
     rewrite_pregrad,
 )
@@ -588,6 +587,8 @@ def compile_logp(
     ) -> PointFunc:
         """Compiled log probability density function.

+        The function expects as input a dictionary with the same structure as self.initial_point()
+
         Parameters
         ----------
         vars : list of random variables or potential terms, optional
@@ -599,7 +600,12 @@ def compile_logp(
             Whether to sum all logp terms or return elemwise logp for each variable.
             Defaults to True.
         """
-        return self.compile_fn(self.logp(vars=vars, jacobian=jacobian, sum=sum), **compile_kwargs)
+        compile_kwargs.setdefault("on_unused_input", "ignore")
+        return self.compile_fn(
+            inputs=self.value_vars,
+            outs=self.logp(vars=vars, jacobian=jacobian, sum=sum),
+            **compile_kwargs,
+        )

     def compile_dlogp(
         self,
@@ -609,6 +615,9 @@ def compile_dlogp(
     ) -> PointFunc:
         """Compiled log probability density gradient function.

+        The function expects as input a dictionary with the same structure as self.initial_point()
+
+
         Parameters
         ----------
         vars : list of random variables or potential terms, optional
@@ -617,7 +626,12 @@ def compile_dlogp(
         jacobian : bool
             Whether to include jacobian terms in logprob graph. Defaults to True.
         """
-        return self.compile_fn(self.dlogp(vars=vars, jacobian=jacobian), **compile_kwargs)
+        compile_kwargs.setdefault("on_unused_input", "ignore")
+        return self.compile_fn(
+            inputs=self.value_vars,
+            outs=self.dlogp(vars=vars, jacobian=jacobian),
+            **compile_kwargs,
+        )

     def compile_d2logp(
         self,
@@ -628,6 +642,8 @@ def compile_d2logp(
     ) -> PointFunc:
         """Compiled log probability density hessian function.

+        The function expects as input a dictionary with the same structure as self.initial_point()
+
         Parameters
         ----------
         vars : list of random variables or potential terms, optional
@@ -636,8 +652,10 @@ def compile_d2logp(
         jacobian : bool
             Whether to include jacobian terms in logprob graph. Defaults to True.
         """
+        compile_kwargs.setdefault("on_unused_input", "ignore")
         return self.compile_fn(
-            self.d2logp(vars=vars, jacobian=jacobian, negate_output=negate_output),
+            inputs=self.value_vars,
+            outs=self.d2logp(vars=vars, jacobian=jacobian, negate_output=negate_output),
             **compile_kwargs,
         )

@@ -742,7 +760,7 @@ def dlogp(
            dlogp graph
        """
        if vars is None:
-           value_vars = None
+           value_vars = self.continuous_value_vars
        else:
            if not isinstance(vars, list | tuple):
                vars = [vars]
@@ -782,7 +800,7 @@ def d2logp(
            d²logp graph
        """
        if vars is None:
-           value_vars = None
+           value_vars = self.continuous_value_vars
        else:
            if not isinstance(vars, list | tuple):
                vars = [vars]
@@ -1616,7 +1634,7 @@ def compile_fn(
        outs : Variable or sequence of Variables
            PyTensor variable or iterable of PyTensor variables.
        inputs : sequence of Variables, optional
-           PyTensor input variables, defaults to pytensorf.inputvars(outs).
+           PyTensor input variables, Required if there is more than one input.
        mode
            PyTensor compilation mode, default=None.
        point_fn : bool
@@ -1630,7 +1648,11 @@ def compile_fn(
            Compiled PyTensor function
        """
        if inputs is None:
-           inputs = inputvars(outs)
+           inputs = list(explicit_graph_inputs(outs))
+           if len(inputs) > 1:
+               raise ValueError(
+                   "compile_fn requires inputs to be specified when there is more than one input."
+               )

        with self:
            fn = compile(
@@ -1793,7 +1815,7 @@ def point_logps(self, point=None, round_vals=2, **kwargs):
                factor.name: np.round(np.asarray(factor_logp), round_vals)
                for factor, factor_logp in zip(
                    factors,
-                   self.compile_fn(factor_logps_fn, **kwargs)(point),
+                   self.compile_fn(inputs=self.value_vars, outs=factor_logps_fn, **kwargs)(point),
                )
            }

@@ -2126,8 +2148,8 @@ def compile_fn(
    ----------
    outs
        PyTensor variable or iterable of PyTensor variables.
-   inputs
-       PyTensor input variables, defaults to pytensorf.inputvars(outs).
+   inputs, optional
+       PyTensor input variables. Required if there is more than one input.
    mode
        PyTensor compilation mode, default=None.
    point_fn : bool
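In practical terms, the changes to compile_logp, compile_dlogp, compile_d2logp, and Model.compile_fn above mean that all compiled model functions accept the full initial_point()-style dictionary and ignore the entries they do not need. A minimal sketch of the intended calling convention (the toy model, variable names, and data below are invented for illustration, not part of this commit):

import numpy as np
import pymc as pm

# Hypothetical toy model, used only to illustrate the calling convention.
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=np.array([0.1, -0.3, 0.2]))

    logp_fn = model.compile_logp()              # logp of all free variables
    dlogp_mu = model.compile_dlogp(vars=[mu])   # gradient w.r.t. mu only

# Every compiled function takes the same dict, shaped like model.initial_point();
# entries it does not need are ignored thanks to on_unused_input="ignore".
point = model.initial_point()
print(logp_fn(point))
print(dlogp_mu(point))

# Model.compile_fn likewise no longer guesses the input order: when the output
# graph has more than one input, `inputs` must be passed explicitly.
fn = model.compile_fn(inputs=model.value_vars, outs=model.logp())
print(fn(point))

Passing a partial dict such as {"mu": 0.0} alone is no longer enough when the model has several free variables, which is why the mixture test further down merges its override into model.initial_point() via ip | {...} instead of passing a bare value dict.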

pymc/pytensorf.py
Lines changed: 19 additions & 5 deletions

@@ -35,7 +35,7 @@
 )
 from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.op import HasInnerGraph
-from pytensor.graph.traversal import graph_inputs, walk
+from pytensor.graph.traversal import explicit_graph_inputs, graph_inputs, walk
 from pytensor.scalar.basic import Cast
 from pytensor.scan.op import Scan
 from pytensor.tensor.basic import _as_tensor_variable
@@ -165,7 +165,7 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
         mask[mask_idx] = 1
         return np.ma.MaskedArray(array_data, mask)

-    if not inputvars(x) and not rvs_in_graph(x):
+    if not len(list(explicit_graph_inputs(x))) and not rvs_in_graph(x):
         return x.eval(mode=_cheap_eval_mode)

     raise TypeError(f"Data cannot be extracted from {x}")
@@ -244,15 +244,17 @@ def cont_inputs(a):
    """
    Get the continuous inputs into PyTensor variables.

+   NOTE: No particular order is guaranteed across PyTensor versions
+
    Parameters
    ----------
    a: PyTensor variable

    Returns
    -------
-   r: list of tensor variables that are continuous inputs
+   r: list of tensor variables that are continuous inputs.
    """
-   return typefilter(inputvars(a), continuous_types)
+   return typefilter(explicit_graph_inputs(a), continuous_types)


 def floatX(X):
@@ -310,6 +312,10 @@ def gradient1(f, v):
 def gradient(f, vars=None):
     if vars is None:
         vars = cont_inputs(f)
+        if len(vars) > 1:
+            raise ValueError(
+                "gradient requires vars to be specified when there is more than one input."
+            )

     if vars:
         return pt.concatenate([gradient1(f, v) for v in vars], axis=0)
@@ -331,6 +337,10 @@ def grad_i(i):
 def jacobian(f, vars=None):
     if vars is None:
         vars = cont_inputs(f)
+        if len(vars) > 1:
+            raise ValueError(
+                "jacobian requires vars to be specified when there is more than one input."
+            )

     if vars:
         return pt.concatenate([jacobian1(f, v) for v in vars], axis=1)
@@ -378,6 +388,10 @@ def hess_ii(i):
 def hessian_diag(f, vars=None, negate_output=True):
     if vars is None:
         vars = cont_inputs(f)
+        if len(vars) > 1:
+            raise ValueError(
+                "hessian_diag requires vars to be specified when there is more than one input."
+            )

     if vars:
         res = pt.concatenate([hessian_diag1(f, v) for v in vars], axis=0)
@@ -612,7 +626,7 @@ def __call__(self, input):
        ----------
        input: TensorVariable
        """
-       (oldinput,) = inputvars(self.tensor)
+       (oldinput,) = explicit_graph_inputs(self.tensor)
        return pytensor.clone_replace(self.tensor, {oldinput: input}, rebuild_strict=False)
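The same principle applies to the graph helpers in pymc/pytensorf.py: when a graph has more than one explicit input, gradient, jacobian, and hessian_diag no longer infer an ordering and instead require vars to be passed. A small sketch, with a made-up two-input graph:

import pytensor.tensor as pt

from pymc.pytensorf import gradient

# Hypothetical two-input graph; there is no "natural" ordering of its inputs.
a = pt.scalar("a")
b = pt.scalar("b")
cost = a**2 + 3 * b

# With an explicit vars list the order of the stacked gradient is unambiguous:
# first d(cost)/da, then d(cost)/db.
g = gradient(cost, vars=[a, b])

# Omitting vars on a multi-input graph now raises instead of silently picking an order.
try:
    gradient(cost)
except ValueError as err:
    print(err)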

pymc/sampling/forward.py
Lines changed: 1 addition & 1 deletion

@@ -1055,7 +1055,7 @@ def vectorize_over_posterior(
     # These variables need to be resized because they won't be resized implicitly by
     # the replacement of the needed_rvs or other random variables in the graph when we
     # later call vectorize_graph.
-    independent_rvs: list[TensorVariable] = []
+    independent_rvs: list[Variable] = []
     for rv in [
         rv
         for rv in general_toposort(  # type: ignore[call-overload]

pymc/tuning/scaling.py
Lines changed: 6 additions & 0 deletions

@@ -54,6 +54,8 @@ def find_hessian(point, vars=None, model=None, negate_output=True):
         Variables for which Hessian is to be calculated.
     """
     model = modelcontext(model)
+    if vars is None:
+        vars = model.continuous_value_vars
     H = model.compile_d2logp(vars, negate_output=negate_output)
     return H(Point(point, filter_model_vars=True, model=model))

@@ -70,12 +72,16 @@ def find_hessian_diag(point, vars=None, model=None, negate_output=True):
         Variables for which Hessian is to be calculated.
     """
     model = modelcontext(model)
+    if vars is None:
+        vars = model.continuous_value_vars
     H = model.compile_fn(hessian_diag(model.logp(), vars, negate_output=negate_output))
     return H(Point(point, model=model))


 def guess_scaling(point, vars=None, model=None, scaling_bound=1e-8):
     model = modelcontext(model)
+    if vars is None:
+        vars = model.continuous_value_vars
     try:
         h = -find_hessian_diag(point, vars, model=model, negate_output=False)
     except NotImplementedError:
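For the tuning helpers, the vars default is now resolved up front from model.continuous_value_vars rather than guessed later from graph-input order. A rough sketch of the resulting calls (the single-variable model is invented for the example, and exact return values depend on the model):

import pymc as pm

from pymc.tuning.scaling import find_hessian_diag, guess_scaling

# Hypothetical single-variable model, kept deliberately small for the sketch.
with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0, shape=3)

point = model.initial_point()

# vars=None now falls back to model.continuous_value_vars in these helpers,
# instead of being inferred from whatever order the graph inputs happen to have.
hess_diag = find_hessian_diag(point, model=model)
scaling = guess_scaling(point, model=model)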

pymc/tuning/starting.py
Lines changed: 1 addition & 1 deletion

@@ -185,7 +185,7 @@ def find_MAP(

     mx0 = RaveledVars(mx0, x0.point_map_info)
     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
-    unobserved_vars_values = model.compile_fn(unobserved_vars)(
+    unobserved_vars_values = model.compile_fn(inputs=model.value_vars, outs=unobserved_vars)(
         DictToArrayBijection.rmap(mx0, start)
     )
     mx = {var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)}

tests/distributions/test_mixture.py
Lines changed: 8 additions & 3 deletions

@@ -945,8 +945,8 @@ def test_vector_components(self):
         latent_m = Normal("latent_m", mu=mus[..., z], sigma=1e-5, shape=nd)

         size = 100
-        m_val = draw(m, draws=size, random_seed=998)
-        latent_m_val = draw(latent_m, draws=size, random_seed=998 * 2)
+        m_val = draw(m, draws=size, random_seed=997)
+        latent_m_val = draw(latent_m, draws=size, random_seed=997 * 2)
         assert m_val.shape == latent_m_val.shape
         # Test that each element in axis = -1 comes from the same mixture
         # component
@@ -959,7 +959,12 @@ def test_vector_components(self):

         # Check that mixing of values in the last axis leads to smaller logp
         logp_fn = model.compile_logp(vars=[m])
-        assert logp_fn({"m": [0, 0, 0]}) > logp_fn({"m": [0, 1, 0]}) > logp_fn({"m": [0, 1, 2]})
+        ip = model.initial_point()
+        assert (
+            logp_fn(ip | {"m": [0, 0, 0]})
+            > logp_fn(ip | {"m": [0, 1, 0]})
+            > logp_fn(ip | {"m": [0, 1, 2]})
+        )
         self.logp_matches(m, latent_m, z, npop, model=model)

     def samples_from_same_distribution(self, *args):

tests/logprob/test_utils.py
Lines changed: 3 additions & 3 deletions

@@ -42,13 +42,13 @@
 from pytensor import tensor as pt
 from pytensor.compile import get_default_mode
 from pytensor.graph.basic import equal_computations
-from pytensor.graph.traversal import ancestors
+from pytensor.graph.traversal import ancestors, explicit_graph_inputs
 from pytensor.tensor.random.basic import NormalRV
 from pytensor.tensor.random.op import RandomVariable

 import pymc as pm

-from pymc import SymbolicRandomVariable, inputvars
+from pymc.distributions.distribution import SymbolicRandomVariable
 from pymc.distributions.transforms import Interval
 from pymc.logprob.abstract import MeasurableOp, valued_rv
 from pymc.logprob.basic import logp
@@ -231,7 +231,7 @@ def test_interdependent_transformed_rvs(self, reversed):

         assert_no_rvs(transform_values)
         # Test that we haven't introduced value variables in the random graph (issue #7054)
-        assert not inputvars(rvs)
+        assert not any(list(explicit_graph_inputs(rvs)))

         if reversed:
             transform_values = transform_values[::-1]

tests/model/test_core.py
Lines changed: 5 additions & 6 deletions

@@ -1066,19 +1066,18 @@ def test_model_logp(jacobian):
     if not jacobian:
         expected_y_logp -= np.array([0.0, 1.0])

-    x_logp, y_logp = m.compile_logp(sum=False, jacobian=jacobian)(
-        {"x": test_vals, "y_log__": test_vals}
-    )
+    test_val_dict = {"x": test_vals, "y_log__": test_vals}
+    x_logp, y_logp = m.compile_logp(sum=False, jacobian=jacobian)(test_val_dict)
     assert np.all(np.isclose(x_logp, expected_x_logp))
     assert np.all(np.isclose(y_logp, expected_y_logp))

-    x_logp2 = m.compile_logp(vars=[x], sum=False, jacobian=jacobian)({"x": test_vals})
+    x_logp2 = m.compile_logp(vars=[x], sum=False, jacobian=jacobian)(test_val_dict)
     assert np.all(np.isclose(x_logp2, expected_x_logp))

-    y_logp2 = m.compile_logp(vars=[y], sum=False, jacobian=jacobian)({"y_log__": test_vals})
+    y_logp2 = m.compile_logp(vars=[y], sum=False, jacobian=jacobian)(test_val_dict)
     assert np.all(np.isclose(y_logp2, expected_y_logp))

-    logp_sum = m.compile_logp(sum=True, jacobian=jacobian)({"x": test_vals, "y_log__": test_vals})
+    logp_sum = m.compile_logp(sum=True, jacobian=jacobian)(test_val_dict)
     assert np.isclose(logp_sum, expected_x_logp.sum() + expected_y_logp.sum())