4040
4141from pytensor import tensor as pt
4242from pytensor .graph import RewriteDatabaseQuery
43+ from pytensor .tensor .random .type import random_generator_type
4344from scipy import stats as st
4445
4546from pymc .logprob .basic import conditional_logp , logp
@@ -352,7 +353,7 @@ def test_measurable_dimshuffle(ds_order, multivariate):
352353 np .testing .assert_array_equal (ref_logp_fn (base_test_value ), ds_logp_fn (ds_test_value ))
353354
354355
355- def test_unmeargeable_dimshuffles ():
356+ def test_unmeasurable_dimshuffles ():
356357 # Test that graphs with DimShuffles that cannot be lifted/merged fail
357358
358359 # Initial support axis is at axis=-1
@@ -372,3 +373,155 @@ def test_unmeargeable_dimshuffles():
372373 # TODO: Check that logp is correct if this type of graphs is ever supported
373374 with pytest .raises (RuntimeError , match = "could not be derived" ):
374375 conditional_logp ({w : w_vv })
376+
377+
class TestMeasurableSplit:
    """Tests for deriving log-probabilities of the outputs of ``pt.split``."""

    def test_univariate(self):
        """Split a (6, 5) normal along each axis and compare each part with scipy."""
        rng = np.random.default_rng(388)
        mu = np.arange(6)[:, None]
        sigma = np.arange(5) + 1

        x = pt.random.normal(mu, sigma, size=(6, 5), name="x")

        # axis=0
        x_parts = pt.split(x, splits_size=[2, 4], n_splits=2, axis=0)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        # Splitting along axis 0 partitions the rows of `mu`; `sigma` is shared.
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], mu[:2], sigma),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], mu[2:], sigma),
        )

        # axis=1
        x_parts = pt.split(x, splits_size=[2, 1, 2], n_splits=3, axis=1)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
        # Splitting along axis 1 partitions the entries of `sigma`; `mu` is shared.
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], mu, sigma[:2]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], mu, sigma[2:3]),
        )
        np.testing.assert_allclose(
            logp_x3_eval,
            st.norm.logpdf(x_parts_test[2], mu, sigma[3:]),
        )

    def test_multivariate(self):
        """Split a (3, 5) Dirichlet along its batch axis and along its last axis."""

        @np.vectorize(signature="(n),(n)->()")
        def scipy_dirichlet_logpdf(x, alpha):
            """Compute the logpdf of a Dirichlet distribution using scipy."""
            return st.dirichlet.logpdf(x, alpha)

        # (3, 5) Dirichlet
        rng = np.random.default_rng(426)
        rng_pt = random_generator_type("rng")
        alpha = np.linspace(1, 10, 5) * np.array([1, 10, 100])[:, None]
        x = pt.random.dirichlet(alpha, rng=rng_pt)

        # axis=-2 (i.e., 0, - batch dimension)
        x_parts = pt.split(x, splits_size=[2, 1], n_splits=2, axis=-2)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert logp_parts[0].type.shape == (2,)
        assert logp_parts[1].type.shape == (1,)

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        # Draw the test values from the split graph itself, seeded through `rng_pt`.
        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            scipy_dirichlet_logpdf(x_parts_test[0], alpha[:2]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            scipy_dirichlet_logpdf(x_parts_test[1], alpha[2:]),
        )

        # axis=-1 (i.e., 1, - support dimension)
        x_parts = pt.split(x, splits_size=[2, 3], n_splits=2, axis=-1)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        assert logp_parts[0].type.shape == (3,)
        assert logp_parts[1].type.shape == (3,)
        logp_fn = pytensor.function(x_parts_vv, logp_parts)

        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        # The parts have sizes 2 and 3: check that their logps are in that same
        # 2:3 ratio and that together they add up to the logp of the full draw.
        np.testing.assert_allclose(logp_x1_eval * 3, logp_x2_eval * 2)
        logp_total = logp_x1_eval + logp_x2_eval
        np.testing.assert_allclose(
            logp_total,
            scipy_dirichlet_logpdf(np.concatenate(x_parts_test, axis=1), alpha),
        )

    @pytest.mark.xfail(
        reason="Rewrite from partial split to split on subtensor not implemented yet"
    )
    def test_not_all_splits_used(self):
        """Request the logp of only the first and last of three splits."""
        x = pt.random.normal(mu=pt.arange(6), name="x")
        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
            ::2
        ]  # Only use the first and last splits (the middle one is dropped)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert len(logp_parts) == 2

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        # Evaluate the split outputs to obtain test values. The value variables
        # are inputs of `logp_fn` above, so they have nothing to evaluate.
        x_parts_test = [x_part.eval() for x_part in x_parts]
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], loc=[0, 1]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], loc=[4, 5]),
        )

    def test_not_all_splits_used_core_dim(self):
        """Requesting only some of the splits of a 1-d Dirichlet must raise."""
        # TODO: We could support this for univariate/batch dimensions by rewriting as
        #  split(x, splits_size=[2, 2, 2], n_splits=3, axis=1)[:2] -> split(x[:-2], splits_size=[2, 2], n_splits=2, axis=1)
        #  And letting logp infer the probability of x[:-2]
        x = pt.random.dirichlet(alphas=pt.ones(6), name="x")
        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
            :2
        ]  # Only use first two splits
        x_parts_vv = [x_part.clone() for x_part in x_parts]

        with pytest.raises(
            ValueError,
            match="Split logp requires the number of values to match the number of splits",
        ):
            conditional_logp(dict(zip(x_parts, x_parts_vv)))

    @pytest.mark.xfail(reason="Rewrite from subtensor to split not implemented yet")
    def test_subtensor_converted_to_splits(self):
        """Contiguous slices that cover a variable should behave like a split."""
        rng = np.random.default_rng(388)
        x = pt.random.normal(mu=pt.arange(5), name="x")

        x_parts = [x[:2], x[2:3], x[3:]]
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert len(logp_parts) == 3
        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(logp_x1_eval, st.norm.logpdf(x_parts_test[0], loc=[0, 1]))
        np.testing.assert_allclose(logp_x2_eval, st.norm.logpdf(x_parts_test[1], loc=[2]))
        np.testing.assert_allclose(logp_x3_eval, st.norm.logpdf(x_parts_test[2], loc=[3, 4]))
0 commit comments