@@ -1471,11 +1471,15 @@ def test_with_lkjcorr_matrix(
14711471 prior = pm .sample_prior_predictive (draws = 10 , return_inferencedata = False )
14721472
14731473 assert prior ["corr_mat" ].shape == (10 , 3 , 3 ) # square
1474- assert np .allclose (prior ["corr_mat" ][:, [0 , 1 , 2 ], [0 , 1 , 2 ]], 1.0 ) # 1.0 on diagonal
14751474 assert (prior ["corr_mat" ] == prior ["corr_mat" ].transpose (0 , 2 , 1 )).all () # symmetric
1476- assert (
1477- prior ["corr_mat" ].max () <= 1.0 and prior ["corr_mat" ].min () >= - 1.0
1478- ) # constrained between -1 and 1
1475+
1476+ np .testing .assert_allclose (
1477+ prior ["corr_mat" ][:, [0 , 1 , 2 ], [0 , 1 , 2 ]], 1.0
1478+ ) # 1.0 on diagonal
1479+
1480+ # constrained between -1 and 1
1481+ assert prior ["corr_mat" ].max () <= (1.0 + 1e-12 )
1482+ assert prior ["corr_mat" ].min () >= (- 1.0 - 1e-12 )
14791483
14801484 def test_issue_3758 (self ):
14811485 np .random .seed (42 )
@@ -2172,8 +2176,6 @@ class TestLKJCorr(BaseTestDistributionRandom):
21722176 ]
21732177
21742178 def check_draws_match_expected (self ):
2175- from pymc .distributions import CustomDist
2176-
21772179 def ref_rand (size , n , eta ):
21782180 shape = int (n * (n - 1 ) // 2 )
21792181 beta = eta - 1 + n / 2
@@ -2182,16 +2184,9 @@ def ref_rand(size, n, eta):
21822184
21832185 # If passed as a domain, continuous_random_tester would make `n` a shared variable
21842186 # But this RV needs it to be constant in order to define the inner graph
2185- def lkj_corr_tril (n , eta , shape = None ):
2186- tril_idx = pt .tril_indices (n )
2187- return _LKJCorr .dist (n = n , eta = eta , shape = shape )[..., tril_idx [0 ], tril_idx [1 ]]
2188-
2189- def SlicedLKJ (name , n , eta , * args , shape = None , ** kwargs ):
2190- return CustomDist (name , n , eta , dist = lkj_corr_tril , shape = shape )
2191-
21922187 for n in (2 , 10 , 50 ):
21932188 continuous_random_tester (
2194- SlicedLKJ ,
2189+ _LKJCorr ,
21952190 {
21962191 "eta" : Domain ([1.0 , 10.0 , 100.0 ], edges = (None , None )),
21972192 },
@@ -2204,7 +2199,7 @@ def SlicedLKJ(name, n, eta, *args, shape=None, **kwargs):
22042199@pytest .mark .parametrize ("shape" , [(2 , 2 ), (3 , 2 , 2 )], ids = ["no_batch" , "with_batch" ])
22052200def test_LKJCorr_default_transform (shape ):
22062201 with pm .Model () as m :
2207- x = pm .LKJCorr ("x" , n = 2 , eta = 1 , shape = shape , return_matrix = False )
2202+ x = pm .LKJCorr ("x" , n = 2 , eta = 1 , shape = shape )
22082203 assert isinstance (m .rvs_to_transforms [x ], CholeskyCorrTransform )
22092204 assert m .logp (sum = False )[0 ].type .shape == shape [:- 2 ]
22102205
0 commit comments