First of all, thanks for the package and all your hard work!
I think I've encountered a couple of issues/bugs in the computation of the R-hat statistic.
The first is, I think, just a typo. In lines 139-140 the chain means and variances are flattened by the list comprehension, whereas I think something like:
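```python
_means = np.vstack(means)  # shape (nchains, ndim): one row per chain
_vars = np.vstack(vars)    # same layout for the per-chain variances
```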
would keep the structure of the chains, so that `np.var(_means, ddof=1, axis=0)` and `np.mean(_vars, axis=0)` give the between-chain and within-chain variances, respectively, for every parameter. Right now they end up as scalars, because `_means` and `_vars` are flat.
The second issue is similar in spirit to issue #22 ("mk" autocorrelation calculation), but with a more significant effect here: on lines 120-121, where each split is reshaped to `(-1, ndim)`, samples from all walkers are collapsed into each split. This homogenizes the splits and leads to unrealistically low R-hat values. In my case I had ~28 walkers, many of which were stuck in well-separated modes and barely mixing at all, yet the split R-hat recorded by the callback was quite low, because each of the two splits contained samples from all 28 walkers and the splits were therefore statistically similar.
Since this is an ensemble method, I had to spend some time convincing myself, but I really think R-hat should be computed across all (possibly split) walkers rather than by grouping them together. I can share my trace plots and make a case for this in more detail if it's helpful. With the first change above in place, fixing this would just be a matter of removing the reshape operation. Then `nsplits` would determine how many splits are made within each walker.
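Here is a minimal sketch of the two approaches on a synthetic toy (not my actual run): 28 walkers, each stuck near one of two well-separated modes. It assumes samples are stored as `(nsteps, nwalkers, ndim)`, which may not match the package's internal layout, and `gelman_rubin`, `split_rhat_pooled` and `split_rhat_per_walker` are hypothetical helpers. It uses the classic Gelman-Rubin R-hat, without the rank normalization of newer variants.

```python
import numpy as np


def gelman_rubin(chains):
    """Classic R-hat per parameter for chains of shape (m, n, ndim)."""
    m, n, _ = chains.shape
    chain_means = chains.mean(axis=1)        # (m, ndim)
    chain_vars = chains.var(axis=1, ddof=1)  # (m, ndim)
    B = n * chain_means.var(axis=0, ddof=1)  # between-chain variance
    W = chain_vars.mean(axis=0)              # within-chain variance
    var_hat = (n - 1) / n * W + B / n        # pooled variance estimate
    return np.sqrt(var_hat / W)


def _split(samples, nsplits):
    """Drop leftover steps and cut the step axis into nsplits equal pieces.

    Returns shape (nsplits, steps_per_split, nwalkers, ndim).
    """
    nsteps, nwalkers, ndim = samples.shape
    length = nsteps // nsplits
    return samples[: length * nsplits].reshape(nsplits, length, nwalkers, ndim)


def split_rhat_pooled(samples, nsplits=2):
    """Mimics the (-1, ndim) reshape: each split pools all walkers into one chain."""
    pieces = _split(samples, nsplits)
    k, length, nwalkers, ndim = pieces.shape
    return gelman_rubin(pieces.reshape(k, length * nwalkers, ndim))


def split_rhat_per_walker(samples, nsplits=2):
    """Each walker is cut into nsplits pieces, and every piece is its own chain."""
    pieces = _split(samples, nsplits)
    k, length, nwalkers, ndim = pieces.shape
    per_walker = pieces.transpose(2, 0, 1, 3)  # (nwalkers, nsplits, length, ndim)
    return gelman_rubin(per_walker.reshape(nwalkers * k, length, ndim))


# Toy ensemble: even walkers sit near -5, odd walkers near +5, no mixing.
rng = np.random.default_rng(1)
nsteps, nwalkers, ndim = 1000, 28, 2
centers = np.where(np.arange(nwalkers) % 2 == 0, -5.0, 5.0)
samples = centers[None, :, None] + rng.normal(scale=0.1, size=(nsteps, nwalkers, ndim))

pooled = split_rhat_pooled(samples)
per_walker = split_rhat_per_walker(samples)
print(pooled)      # close to 1: the stuck walkers go unnoticed
print(per_walker)  # far above 1: the lack of mixing is detected
```

In the pooled version both splits contain the same mixture of the two modes, so their means agree and the between-chain term vanishes, while the within-chain term absorbs the full separation between modes. Treating each walker piece as a chain moves that separation into the between-chain term, where R-hat is designed to see it.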
Thanks again, and let me know what your thoughts are. I'm happy to help implement these changes, too.