Supporting more shapes #78

Closed
sethaxen opened this issue Apr 8, 2023 · 2 comments


sethaxen commented Apr 8, 2023

Currently the more modern methods support the following shapes:

  • bfmi: (ndraws[, nchains]) -> ([nchains,])
  • mcse/rhat/ess/ess_rhat: (ndraws, nchains, nparams) -> (nparams,)
  • rstar: (ndraws, nchains, nparams) -> result; other shapes are supported indirectly via the Tables interface.

Looking at these, I think there's a clear continuum of shapes we can interpret and support: (ndraws[, nchains[, nparams...]]). While bfmi can obviously only support the first 2 dimensions, the others can also support the vector case and trailing param dimensions, as might result from stacking chains of draws of matrix random variables.
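To make the proposed shape convention concrete, here is a minimal sketch using only Base Julia. The helper name `map_params` and its signature are hypothetical and not part of MCMCDiagnosticTools; it assumes a statistic `f` that takes a draws-by-chains matrix and returns a scalar. Missing chain and parameter dimensions are treated as size 1, and the result takes the shape of the trailing parameter dimensions. It works on plain arrays only, so it does not address preserving named dimensions.

```julia
# Hypothetical helper (not part of MCMCDiagnosticTools): apply a per-parameter
# statistic `f` to an array of shape (ndraws[, nchains[, nparams...]]).
# `f` is assumed to map a draws-by-chains matrix to a scalar.
function map_params(f, x::AbstractArray)
    ndims(x) >= 1 || throw(ArgumentError("expected at least a draws dimension"))
    ndraws = size(x, 1)
    nchains = size(x, 2)  # size(x, d) returns 1 when d > ndims(x)
    # Shape of the trailing parameter dimensions; () for vectors and matrices.
    param_shape = ntuple(i -> size(x, i + 2), max(ndims(x) - 2, 0))
    # Flatten all parameter dimensions into a single third dimension.
    x3 = reshape(x, ndraws, nchains, :)
    out = [f(view(x3, :, :, k)) for k in axes(x3, 3)]
    # Restore the parameter dimensions in the result.
    return reshape(out, param_shape)
end

# Usage with a stand-in statistic (`sum`), not an actual diagnostic:
x = reshape(collect(1.0:24.0), 2, 3, 2, 2)  # (draw, chain, param1, param2)
r = map_params(sum, x)                      # one value per (param1, param2)
```

With this convention a vector or matrix input yields a zero-dimensional result, since there are no parameter dimensions to report.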

Why do this? While a user can always reshape to a 3D array, this is not ideal for arrays with named dimensions/indices, since reshape discards all named dimensions. For example:

julia> using DimensionalData, MCMCDiagnosticTools

julia> da1 = DimArray(randn(1000, 1, 1), (:draw, :chain, :param));

julia> ess(da1)  # ideal case, dimensions preserved
1-element DimArray{Float64,1} with dimensions: Dim{:param}
 1  960.129

julia> da2 = DimArray(randn(1000, 1), (:draw, :chain));

julia> ess(reshape(da2, size(da2)..., 1))  # named dimensions lost
1-element Vector{Float64}:
 1036.1139385722638

julia> da3 = DimArray(randn(1000), (:draw,));

julia> ess(reshape(da3, size(da3)..., 1, 1))  # named dimensions lost
1-element Vector{Float64}:
 929.3129270202691

julia> da4 = DimArray(randn(1000, 4, 3, 4), (:draw, :chain, :param1, :param2));

julia> ess(reshape(da4, size(da4,1), size(da4,2), :))  # named dimensions lost
12-element Vector{Float64}:
 3950.8133148219445
 3963.51341805499
 3887.7997316174083
 4058.9638959410036
 3685.782636881455
 3633.6820465330984
 3868.9859706770994
 3792.095865124389
 3921.8966121239305
 4111.7648208639375
 3561.7107571982556
 3508.2107768022665

In ArviZ, which uses DimensionalData.DimArrays to store draws, every call to one of these methods requires quite a bit of boilerplate: reshape the input for MCMCDiagnosticTools, reshape the result back, then reattach the named dimensions. The proposed generalization is still unambiguous and allows the functions to be used more ergonomically in such cases.

@sethaxen
Member Author

@devmotion what do you think of this proposal?

@sethaxen
Member Author

Fixed by #79
