Better supporting draws that are arbitrary Julia types #11

Open
sethaxen opened this issue Aug 25, 2022 · 4 comments
@sethaxen
Member

This issue continues discussion starting at #8 (comment).

Some Julia PPLs can return draws as arbitrary Julia types. Here's an example with Soss:

julia> using Soss

julia> struct Foo
           x
           y
           tag
       end

julia> mod1 = @model n begin
           x ~ Normal()
           y ~ Normal() |> iid(n)
           return Foo(x, y, :discrete)
       end;

julia> mod2 = @model n begin
           t ~ mod1(n)
           z ~ Normal(μ=t.x)
       end;

julia> rand(mod2(3))
(t = Foo(-1.8692299945695137, [-0.9020275201942468, -0.9380392196474631, -0.041490841817294566], :discrete), z = -3.0169797376605327)

Currently such types can be stored in InferenceData:

julia> using InferenceObjects

julia> data = (; a = [rand(mod2(3)) for _ in 1:4, _ in 1:100]);

julia> idata = InferenceData(posterior=namedtuple_to_dataset(data))
InferenceData with groups:
  > posterior

julia> idata.posterior
Dataset with dimensions: 
  Dim{:chain} Sampled Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:draw} Sampled Base.OneTo(100) ForwardOrdered Regular Points
and 1 layer:
  :a NamedTuple{(:t, :z), Tuple{Foo, Float64}} dims: Dim{:chain}, Dim{:draw} (4×100)

with metadata OrderedCollections.OrderedDict{Symbol, Any} with 1 entry:
  :created_at => "2022-08-25T11:15:48.582"

julia> idata.posterior.a
4×100 DimArray{NamedTuple{(:t, :z), Tuple{Foo, Float64}},2} a with dimensions: 
  Dim{:chain} Sampled Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:draw} Sampled Base.OneTo(100) ForwardOrdered Regular Points
      100
 1        (t = Foo(-0.332783, [-0.271914, -1.19732, 0.239832], :discrete), z = -0.473026)
 2        (t = Foo(-1.03842, [-0.148646, -0.102317, 0.242476], :discrete), z = -1.53602)
 3        (t = Foo(0.902033, [-0.798571, 0.173176, 0.533269], :discrete), z = 1.45302)
 4        (t = Foo(-0.98641, [1.89491, -0.674791, -0.203847], :discrete), z = -1.44689)

So InferenceData can be used for this storage, but it's not very useful, for several reasons:

  • All downstream diagnostics, statistics, plots, and serialization to NetCDF/Zarr require access to marginals, so we need flat multidimensional arrays, often with numeric types.
  • The Tables interface is not useful for such types (see below), so users can't easily construct custom plots.
  • We can't assign named dimensions to arrays inside such nested objects, which are very useful for plotting.

Here's an example of what the Tables interface would produce:

julia> using DataFrames

julia> DataFrame(idata.posterior)
400×3 DataFrame
 Row │ chain  draw   a
     │ Int64  Int64  NamedTup…
─────┼─────────────────────────────────────────────────
   1 │     1      1  (t = Foo(0.235607, [1.08405, 0.9…
   2 │     2      1  (t = Foo(-1.24972, [-1.89301, 0.…
   3 │     3      1  (t = Foo(1.0526, [-0.179664, -0.…
   4 │     4      1  (t = Foo(0.793393, [0.558985, 0.…
  ⋮  │   ⋮      ⋮                    ⋮
 398 │     2    100  (t = Foo(-1.03842, [-0.148646, -…
 399 │     3    100  (t = Foo(0.902033, [-0.798571, 0…
 400 │     4    100  (t = Foo(-0.98641, [1.89491, -0.…
                                       393 rows omitted

So plotting packages that use the Tables interface, like AlgebraOfGraphics and StatsPlots, are not terribly useful here without lots of additional code.

There are several ways we might approach this:

  1. Do nothing. Users are free to use arbitrary types with InferenceData, and they are expected to turn their types into whatever marginals they care about when they want to use the downstream functions we discussed above. This is the current state.
  2. Require all converters to flatten to the marginals. The converter might encode some of the structure into the Dataset, e.g. the above example might be converted to a Dataset with variable names a.t.x, a.t.y, a.t.tag, and a.z. If we go this route, InferenceData would be a secondary data type used only for some analyses, not a possible default for such PPLs, since it loses some of the structure in the initial draws.
  3. Define an interface for computing a "marginal representation" of a variable, dataset, or whole InferenceData. This would be called by the user to convert a non-flattened InferenceData to a flattened one, allowing provision of named dimensions. e.g. such a function would map the above posterior to something like:
julia> using Compat

julia> a = idata.posterior.a;

julia> d = (;
           var"a.t.x"=map(x -> x.t.x, a),
           var"a.t.y"=permutedims(Compat.stack(map(x -> x.t.y, a)), (2, 3, 1)),
           var"a.t.tag"=map(x -> x.t.tag, a),
           var"a.z"=map(x -> x.z, a),
       );

julia> post_new = namedtuple_to_dataset(d)
Dataset with dimensions: 
  Dim{:chain} Sampled Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:draw} Sampled Base.OneTo(100) ForwardOrdered Regular Points,
  Dim{:a.t.y_dim_1} Sampled Base.OneTo(3) ForwardOrdered Regular Points
and 4 layers:
  :a.t.x   Float64 dims: Dim{:chain}, Dim{:draw} (4×100)
  :a.t.y   Float64 dims: Dim{:chain}, Dim{:draw}, Dim{:a.t.y_dim_1} (4×100×3)
  :a.t.tag Symbol dims: Dim{:chain}, Dim{:draw} (4×100)
  :a.z     Float64 dims: Dim{:chain}, Dim{:draw} (4×100)

with metadata OrderedCollections.OrderedDict{Symbol, Any} with 1 entry:
  :created_at => "2022-08-25T12:08:52.898"

julia> DataFrame(post_new)
1200×7 DataFrame
  Row │ chain  draw   a.t.y_dim_1  a.t.x      a.t.y       a.t.tag   a.z        
      │ Int64  Int64  Int64        Float64    Float64     Symbol    Float64    
──────┼────────────────────────────────────────────────────────────────────────
    1 │     1      1            1   0.235607   1.08405    discrete   1.20053
    2 │     2      1            1  -1.24972   -1.89301    discrete   0.0482549
    3 │     3      1            1   1.0526    -0.179664   discrete   4.70509
    4 │     4      1            1   0.793393   0.558985   discrete   0.919428
    5 │     1      2            1  -1.53217    1.41717    discrete  -1.89603
  ⋮   │   ⋮      ⋮        ⋮          ⋮          ⋮           ⋮          ⋮
 1196 │     4     99            3  -1.65774    0.689958   discrete  -1.00366
 1197 │     1    100            3  -0.332783   0.239832   discrete  -0.473026
 1198 │     2    100            3  -1.03842    0.242476   discrete  -1.53602
 1199 │     3    100            3   0.902033   0.533269   discrete   1.45302
 1200 │     4    100            3  -0.98641   -0.203847   discrete  -1.44689
                                                              1190 rows omitted

The easiest way I can think of to provide such a default is to recur through all Julia types and allocate new arrays as done above, but there may be other options using custom Julia arrays. @oschulz, @cscherrer, would the types you have been suggesting allow for this?
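As a rough illustration of what such a default recursion could look like, here is a minimal sketch in plain Julia. The helper name `flatten_marginals` and its leaf heuristic are assumptions for illustration, not an existing API:

```julia
# Hypothetical sketch: recursively flatten NamedTuples/structs into
# dot-separated flat names, as in option 3 above.
# Leaves are values we keep as-is; everything else is recursed into.
is_leaf(x) = x isa Union{Number, Symbol, AbstractString, AbstractArray{<:Number}}

function flatten_marginals(prefix::AbstractString, x)
    is_leaf(x) && return Pair{String,Any}[String(prefix) => x]
    out = Pair{String,Any}[]
    # fieldnames/getfield work uniformly for NamedTuples and plain structs
    for name in fieldnames(typeof(x))
        append!(out, flatten_marginals("$prefix.$name", getfield(x, name)))
    end
    return out
end

struct Foo
    x
    y
    tag
end

draw = (t = Foo(0.1, [1.0, 2.0, 3.0], :discrete), z = -0.5)
flatten_marginals("a", draw)
# ["a.t.x" => 0.1, "a.t.y" => [1.0, 2.0, 3.0], "a.t.tag" => :discrete, "a.z" => -0.5]
```

A real implementation would additionally need to stack these per-draw leaves into (chain, draw, ...) arrays, as in the `Compat.stack` example above.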

  4. Develop a completely new data structure that allows arbitrarily deep nesting of data types and assignment of dimensions at any level, but that also implements both a marginal no-copy view that flattens everything and a tabular no-copy view that further concatenates with useful column names. I don't see immediately how this could be done with existing dimensional data types, so it could be as complicated as developing yet another dimensional data package.

Off the top of my head, a few additional criteria for the solution:

  1. The InferenceData type and its basic functionality must be kept in a lightweight package and be as generic as possible. It's not ideal that we even depend on DimensionalData, but we do. If we require a complicated solution with lots of dependencies, it should live in its own package, which PPLs or packages with PPL-specific converters can then depend on.
  2. We can allow type piracy for packages within this organization if necessary, but that's it.
  3. The solution should ideally not require the average user or PPL developer to implement some API for their custom types, i.e. there should be sensible defaults.
  4. While we're focusing on increasing usability within Julia, we cannot sacrifice serialization to data structures like NetCDF for archiving and interop with other languages.

Since the others tagged in this have thought a lot more about this than I have, I'd appreciate any input/suggestions. cc also @ParadaCarleton

@sethaxen
Member Author

@femtomc I wonder if you have input on this as well, since I think Gen traces can also be nested and contain arbitrary Julia types.

@oschulz

oschulz commented Aug 26, 2022

Something that may be helpful in this context: @cscherrer and I had discussed building flatten/unflatten transformations on top of the new transport API in MeasureBase.jl. This would make it possible to automatically generate transforms to/from flat vectors as long as a prior measure is available (it provides the required structural information).

@oschulz

oschulz commented Aug 26, 2022

Another thing that may be interesting in this context: in BAT.jl we've recently added the ability to marginalize/flatten structures to "flat" NamedTuples using Unicode field names. This is currently limited to non-nested input, but the result looks like this: a value (a = [1.2, 2.3], b = 4.2) can be turned into (a⌞1⌟ = 1.2, a⌞2⌟ = 2.3, b = 4.2). We use a few other Unicode characters too, so we can preserve range selection during marginalization and still have valid Unicode field names like (d⌞1ː2⌟ = ...). We introduced it to support value selection for plotting, but we're planning to extend it and make it more directly accessible. Maybe such a "flatten-nested-names-and-ranges-to-unicode" scheme could be useful for ArviZ as well?
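To make the idea concrete, a minimal sketch of this kind of flattening for the non-nested case (the helper `flatten_to_unicode` is hypothetical; only the ⌞ ⌟ bracket convention is taken from the description above):

```julia
# Sketch: flatten vector-valued entries of a NamedTuple into scalar entries
# with Unicode-indexed names, e.g. :a => [1.2, 2.3] becomes a⌞1⌟ and a⌞2⌟.
function flatten_to_unicode(nt::NamedTuple)
    out = Pair{Symbol,Any}[]
    for (k, v) in pairs(nt)
        if v isa AbstractVector
            for i in eachindex(v)
                push!(out, Symbol("$(k)⌞$(i)⌟") => v[i])
            end
        else
            push!(out, k => v)  # scalars pass through unchanged
        end
    end
    return (; out...)
end

flatten_to_unicode((a = [1.2, 2.3], b = 4.2))
# → (a⌞1⌟ = 1.2, a⌞2⌟ = 2.3, b = 4.2)
```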

@cscherrer

Some more transform discussion (hidden to avoid distraction): I think it's also worth noting that transforms/transports can be data-dependent. I don't know how Turing does things, but TransformVariables wants the transform to be static. For Tilde, I think we can make things a lot more flexible: instead of running the model once to determine the transport, the transform can be more dynamic and itself be defined in terms of a model run.

For inference, the simple approach would then run the model twice: once to get the transformed value, and again to get the log-density. But that's inefficient, so I think we'll compute the log-density along the way. If there's a Vector of samples available, we'll write into that as we go.

> All downstream diagnostics, statistics, plots, and serialization to NetCDF/Zarr will require access to marginals, so we need flat multidimensional arrays, often with numeric types.

@sethaxen I think of TupleVectors as making it easy to get to marginals, so maybe I don't understand what you mean by "marginals". Can you give more detail on this?
