Changes to dimension ordering #49

Closed
sethaxen opened this issue Nov 18, 2022 · 8 comments

Comments

@sethaxen
Member

sethaxen commented Nov 18, 2022

Following the discussion in #5, I propose the following dimension interpretations be used uniformly across the package.

  • AbstractVector: (draws,), vector of MC draws for a single parameter
  • AbstractArray{<:Any,3}: (params, draws, chains), array of MCMC draws for multiple parameters

The first is obvious and consistent with the current interpretation. The second is consistent with Julia's default column ordering, as I've explained in arviz-devs/InferenceObjects.jl#8.

I also propose that all of our common diagnostics ultimately implement methods for the AbstractArray{<:Any,3} signature. This lets the user try multiple diagnostics on a single array format without needing different reshapes and slices for each diagnostic, but it doesn't require users to use AbstractArray{<:Any,3} if their draws are not in that format.

When outputs contain a subset of these dimensions, they should preserve the order of the dimensions. I don't think this requires any changes right now.
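As a rough illustration of the two proposed signatures, here is a minimal sketch. `toy_diag` is a hypothetical stand-in (just the sample mean), not a function of this package; it only shows how a vector method and a 3-d method could coexist, with the 3-d method returning one value per parameter so that the `params` dimension is preserved.

```julia
using Statistics

# Hypothetical diagnostic (NOT part of the package): sample mean of the draws.
# Vector method: input has shape (draws,).
toy_diag(draws::AbstractVector) = mean(draws)

# 3-d method: input has shape (params, draws, chains).
# Draws and chains are merged for each parameter; the output has shape (params,).
function toy_diag(x::AbstractArray{<:Any,3})
    return [toy_diag(vec(view(x, p, :, :))) for p in axes(x, 1)]
end

x = reshape(collect(1.0:24.0), 2, 4, 3)  # 2 params, 4 draws, 3 chains
result = toy_diag(x)                     # one value per parameter
```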

Concrete changes

Breaking:

  • discretediag, ess, gelmandiag: (draws, params, chains) -> (params, draws, chains)

New methods:

  • rstar
    • change sample matrix input to have shape (params, draws) instead of (draws, params)
    • add method with (params, draws, chains) input, which is then forwarded to the current method after a reshape and repeat (to generate chain indices)
  • mcse: add an AbstractArray{<:Any,3} method.
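The "reshape and repeat" step proposed for rstar could look roughly like the sketch below. The variable names are illustrative only, and the actual forwarding code in the package may differ; the point is how chain indices line up with the columns of the reshaped matrix.

```julia
# Toy draws with the proposed layout (params, draws, chains).
nparams, ndraws, nchains = 2, 5, 3
x = reshape(collect(1.0:(nparams * ndraws * nchains)), nparams, ndraws, nchains)

# Merge draws and chains into one sample dimension: (params, draws * chains).
# Julia arrays are column-major, so within the merged dimension the draw index
# varies fastest and the chain index slowest.
samples = reshape(x, nparams, ndraws * nchains)

# Chain label for every column of `samples`: 1,1,1,1,1,2,2,...
chain_indices = repeat(1:nchains; inner=ndraws)
```

With this ordering, column `(c - 1) * ndraws + d` of `samples` holds draw `d` of chain `c`.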

Things not changed:

  • bfmi: it makes no sense for energy to have a params dimension, so bfmi does not need a 3d array method
  • gewekediag, heideldiag, rafterydiag: I suspect these are rarely used, so adding AbstractArray{<:Any,3} methods would be low priority

cc @devmotion @ParadaCarleton

@devmotion
Member

devmotion commented Nov 21, 2022

I guess the main drawback is that we would have to use permutedims in MCMCChains which currently uses the layout (draws, params, chains). It would be good to check if/how that impacts performance of the analysis there. I assume if we manage to reduce the amount of permutedims by not re-doing it for every analysis when evaluating multiple methods then it should be fine.

In the long run maybe we might want to change the memory layout in MCMCChains as well - but maybe that's solved by TuringLang/MCMCChains.jl#381 automatically.

@sethaxen
Member Author

> I guess the main drawback is that we would have to use permutedims in MCMCChains which currently uses the layout (draws, params, chains). It would be good to check if/how that impacts performance of the analysis there. I assume if we manage to reduce the amount of permutedims by not re-doing it for every analysis when evaluating multiple methods then it should be fine.

We could use PermutedDimsArray instead, so perhaps this won't have such a big impact. I agree we should benchmark this though for representative Chains objects.
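To make the difference concrete, here is a small sketch using only Base functions on a toy array (not a real Chains object): `permutedims` allocates a reordered copy, while `PermutedDimsArray` is a lazy wrapper that shares memory with its parent.

```julia
# Toy array in the MCMCChains layout (draws, params, chains).
val = reshape(collect(1.0:24.0), 4, 3, 2)

# Eager: allocates a new array with layout (params, draws, chains).
copied = permutedims(val, (2, 1, 3))

# Lazy: wraps `val` without copying; indices are permuted on access.
lazy = PermutedDimsArray(val, (2, 1, 3))

# Mutating the parent afterwards is visible through the lazy wrapper only.
val[1, 2, 1] = -1.0
```

The wrapper avoids the allocation, but element access goes through the permuted indexing, so which option is faster depends on the access pattern of the diagnostic.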

> In the long run maybe we might want to change the memory layout in MCMCChains as well - but maybe that's solved by TuringLang/MCMCChains.jl#381 automatically.

Changing the memory layout would be a major breaking change for downstream user code, but yes, it's worth considering. The linked PR would solve it if a user converted a Chains to an InferenceData once (internally permutes the dims) and then used the InferenceData for subsequent analyses.

@sethaxen
Member Author

sethaxen commented Nov 21, 2022

I ran a benchmark comparing the current version of MCMCChains with a locally updated version that depends on the latest commits in #50 and uses either permutedims or PermutedDimsArray.

This was the benchmark:

using Random, BenchmarkTools, JLD2, MCMCChains
Random.seed!(42)
val = rand(1_000, 100, 8)
chn = Chains(val, 1:100)
suite = BenchmarkGroup()
# suite["discretediag"] = @benchmarkable discretediag($chn)
suite["ess_rhat"] = @benchmarkable ess_rhat($chn)
suite["gelmandiag"] = @benchmarkable gelmandiag($chn)
suite["gelmandiag_multivariate"] = @benchmarkable gelmandiag_multivariate($chn)
# suite["rstar"] = @benchmarkable rstar(rng, $classifier, $chn) setup=(rng = MersenneTwister(42));
results = run(suite; verbose = true)

Here are the combined results showing mean and std (in microseconds):

julia> DataFrame(d)
3×4 DataFrame
 Row │                        ess_rhat            gelmandiag          gelmandiag_multivariate 
     │ String                 Tuple              Tuple              Tuple                  
─────┼────────────────────────────────────────────────────────────────────────────────────────
   1 │ old                    (3.54742, 1.05944)  (32.2188, 317.828)  (24.0996, 41.8181)
   2 │ new_PermutedDimsArray  (3.4504, 0.736588)  (40.4049, 358.265)  (25.2606, 39.8565)
   3 │ new_permutedims        (5.73569, 1.74746)  (48.7768, 408.437)  (26.1484, 41.6385)

And here's the minimum:

julia> DataFrame(d)
3×4 DataFrame
 Row │                        ess_rhat  gelmandiag  gelmandiag_multivariate 
     │ String                 Float64   Float64     Float64                 
─────┼──────────────────────────────────────────────────────────────────────
   1 │ old                     3.02097     4.46727                  18.5102
   2 │ new_PermutedDimsArray   2.92153     5.56936                  20.3158
   3 │ new_permutedims         4.16976     5.39286                  20.3697

Based on this benchmark, I'd suggest using PermutedDimsArray in MCMCChains; with it there's not much of a performance regression.

@sethaxen
Member Author

Looking at it more closely, for at least some variants of mcse it makes sense to be able to specify the dims that correspond to the draws. For example, for x with shape (nparameters, ndraws, nchains), mcse(x; dims=2) would return a (nparameters, nchains) matrix of MCSE values, while mcse(x; dims=(2, 3)) would return a (nparameters,) vector of MCSE values, having merged the chains. MCMCChains currently does the former for computing gewekediag and heideldiag: https://github.com/TuringLang/MCMCChains.jl/blob/master/src/gewekediag.jl

We already use dims to mean "draw dims" for bfmi. I wonder if we should add this keyword then wherever it makes sense. A downside is that one could choose an MCSE method that requires separating the chains, but the dims keyword could contradict this.
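A minimal sketch of the output shapes such a `dims` keyword would produce. `naive_mcse` is a hypothetical helper, not the package's mcse: it uses the naive estimate std / sqrt(n), which assumes independent draws and ignores autocorrelation, and is here only to show how the reduced dimensions are dropped.

```julia
using Statistics

# Hypothetical helper, NOT the package's mcse. `dims` marks the draw dimensions.
function naive_mcse(x::AbstractArray; dims)
    n = prod(size(x, d) for d in dims)   # number of draws being reduced
    s = std(x; dims=dims)                # reduced dims are kept with size 1
    return dropdims(s; dims=dims) ./ sqrt(n)
end

x = randn(3, 100, 4)                 # (nparameters, ndraws, nchains)
per_chain = naive_mcse(x; dims=2)    # one value per parameter and chain
merged = naive_mcse(x; dims=(2, 3))  # chains merged: one value per parameter
```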

@ParadaCarleton
Member

ParadaCarleton commented Nov 23, 2022

I think switching to a common dimension ordering is a great idea and we should get to work on it, although I've since realized it's not clear what's going to be the best layout for memory locality: x[params, draws, chains] is the most efficient when sampling (since all parameters are drawn together in one sample), but x[draws, chains, params] is most efficient for analysis (because it's common to compute summary metrics dimensionwise or chainwise, but not one draw at a time).

In any case, memory-layout is unlikely to be the bottleneck in the MCMC pipeline, so we should probably go with whatever is most natural or most familiar to users (e.g. matching ArviZ's layout in Python).
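The locality argument can be checked on a toy array by looking at the memory strides of the per-parameter slice in each layout. This is only a sketch of the layouts themselves, not a benchmark, and the array sizes are arbitrary.

```julia
x_pdc = rand(3, 1000, 4)                # (params, draws, chains)
x_dcp = permutedims(x_pdc, (2, 3, 1))   # (draws, chains, params)

# All draws of parameter 1 across all chains, in each layout.
s_pdc = strides(view(x_pdc, 1, :, :))   # consecutive draws are 3 elements apart
s_dcp = strides(view(x_dcp, :, :, 1))   # consecutive draws are adjacent in memory
```

In the (draws, chains, params) layout, the draws of one parameter form a single contiguous block, which is what per-parameter summaries read.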

Julia 1.9 might help simplify all of this when Slices are introduced: we can have users pass sliced copies of their arrays, with each slice being a separate chain, giving us a more natural "vector of chains" interpretation.
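A sketch of the "vector of chains" idea using `eachslice` from Base, assuming a (params, draws, chains) array. Note that `eachslice` yields views of the parent array rather than copies; on Julia 1.9 and later it returns a `Slices` array, on earlier versions a generator.

```julia
x = reshape(collect(1.0:24.0), 2, 4, 3)   # (params, draws, chains)

# One view per chain. `collect` gives a plain vector on any Julia version.
chains = collect(eachslice(x; dims=3))

first_chain = chains[1]                   # a (params, draws) view into `x`
```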

@sethaxen
Member Author

@ParadaCarleton after some thought, there are a number of good reasons to use (draw, chain, param...) as a useful default ordering, and I've opened a PR to do so in InferenceObjects: arviz-devs/InferenceObjects.jl#40. I'll similarly update #50 to use this ordering.

@sethaxen
Member Author

Here are some updated results for the benchmark in #49 (comment).

mean and std (in microseconds):

3×4 DataFrame
 Row │                        ess_rhat            gelmandiag          gelmandiag_multivariate 
     │ String                 Tuple              Tuple              Tuple                  
─────┼────────────────────────────────────────────────────────────────────────────────────────
   1 │ old                    (3.41135, 1.11612)  (26.4841, 285.93)   (23.446, 36.6565)
   2 │ new_PermutedDimsArray  (3.58941, 9.25398)  (6.53272, 25.1717)  (21.2084, 3.2262)
   3 │ new_permutedims        (4.70588, 1.4936)   (6.81624, 1.48399)  (24.4894, 4.97561)

minimum:

3×4 DataFrame
 Row │                        ess_rhat  gelmandiag  gelmandiag_multivariate 
     │ String                 Float64   Float64     Float64                 
─────┼──────────────────────────────────────────────────────────────────────
   1 │ old                     2.60093     3.97421                  17.3264
   2 │ new_PermutedDimsArray   2.54193     3.85993                  17.4864
   3 │ new_permutedims         3.04878     4.55253                  18.4763

The runtimes with PermutedDimsArray are more or less equivalent to the old runtimes, except for gelmandiag, for which the new dimension order is on average faster.

@sethaxen
Member Author

Implemented in #50
