Proposed redesign of the conversion pipeline #32

Open
sethaxen opened this issue Nov 3, 2022 · 0 comments
Labels
enhancement New feature or request

Member

sethaxen commented Nov 3, 2022

Currently we have two classes of conversion functions. The first class consists of from_XXX functions, which dispatch on the posterior type and accept a number of keywords specific to that type, e.g. from_namedtuple or from_mcmcchains. The second class consists of the generic functions convert_to_inference_data and convert_to_dataset, whose methods dispatch to the from_XXX functions. These generic functions can even be used within other from_XXX functions, so that groups of one type can be mixed with a posterior of another type.

When users want their type to be convertible to an InferenceData, they generally implement a from_XXX function and a corresponding convert_to_inference_data method.

Here I propose a major redesign of this pipeline, guided by the following principles:

  • There are 2 types of objects we want to convert: objects that contain data for a single group (or part of a group) and objects that contain data for multiple groups.
  • A user may want to split the data in the first type of object into several groups.
  • We drop the prefix convert_to_, since we're not in general doing conversion.
  • The first type of object can be hooked into the pipeline by implementing a single method of dataset.
  • The second type of object can be hooked into the pipeline by implementing methods of both inferencedata and dataset.
  • All conversion functions should absorb unused keywords into kwargs, so that a single inferencedata call can use keywords for multiple conversion methods so long as they don't clash.
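The last principle can be illustrated with a small self-contained sketch. The functions to_posterior, to_observed and build below are hypothetical stand-ins, not part of InferenceObjects; the point is only that each converter declares the keywords it understands and absorbs the rest into kwargs..., so one top-level call can carry keywords meant for several converters.

```julia
# Hypothetical converters (not an InferenceObjects API). Each one names the
# keywords it understands and absorbs any others into `kwargs...`.
to_posterior(x; dims=(), kwargs...) = (data=x, dims=dims)
to_observed(x; default_dims=(), kwargs...) = (data=x, default_dims=default_dims)

# A single top-level call forwards *all* keywords to every converter;
# each converter silently ignores the keywords meant for the others.
function build(post, obs; kwargs...)
    return (posterior=to_posterior(post; kwargs...), observed=to_observed(obs; kwargs...))
end

result = build([1.0, 2.0], [3.0]; dims=(:a,), default_dims=(:b,))
result.posterior.dims          # (:a,)
result.observed.default_dims   # (:b,)
```

Had to_posterior been declared without kwargs..., passing default_dims to it would throw a MethodError for an unsupported keyword argument.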

Working prototype of the pipeline

using InferenceObjects

# fallback to current pipeline for demonstration purposes
dataset(data; kwargs...) = convert_to_dataset(data; kwargs...)

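# an InferenceData passes through unchanged
inferencedata(data::InferenceData; kwargs...) = data
# any other object is treated as the posterior by default
inferencedata(data; kwargs...) = inferencedata(:posterior => data; kwargs...)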
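# convert a single `group_name => data` pair into a one-group InferenceData;
# constant and observed data have no draw/chain dimensions
function inferencedata(data::Pair{Symbol}; kwargs...)
    k, v = data
    ds = if k ∈ (:constant_data, :observed_data)
        dataset(v; default_dims=(), kwargs...)
    else
        dataset(v; kwargs...)
    end
    return InferenceData(; k => ds)
end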
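# convert the leading object first, then fold in the remaining `group => data` pairs
function inferencedata(data, next::Pair{Symbol}, others::Pair{Symbol}...; kwargs...)
    return inferencedata(inferencedata(data; kwargs...), next, others...; kwargs...)
end
function inferencedata(data::InferenceData, next::Pair{Symbol}, others::Pair{Symbol}...; kwargs...)
    # convert `next` on its own, merge it in, then recurse on the remaining pairs
    idata_merged = merge(data, inferencedata(next; kwargs...))
    return inferencedata(idata_merged, others...; kwargs...)
end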

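# a request to move (and optionally rename) variables out of the group `source`;
# `var_map` holds `new_name => source_name` pairs
struct Subset{V}
    source::Symbol
    var_map::V
end
function subset(source::Symbol, var_map::Tuple{Vararg{Union{Symbol,Pair{Symbol,Symbol}}}})
    # a bare name `v` is shorthand for `v => v` (move without renaming)
    var_map_new = map(var_map) do v
        v isa Pair && return v
        return v => v
    end
    return Subset(source, var_map_new)
end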
function inferencedata(data::InferenceData, next::Pair{Symbol,<:Subset}, others::Pair{Symbol}...; kwargs...)
    k, s = next
    source_vars = map(last, s.var_map)
    source_ds = data[s.source]
    # keep only the variables that are not moved out of the source group
    source_ds_new = source_ds[filter(∉(source_vars), keys(source_ds))]
    # extract the moved variables and store them under their new names
    subset_nt = NamedTuple(source_ds[source_vars])
    subset_ds = Dataset(NamedTuple{map(first, s.var_map)}(values(subset_nt)))
    idata_merged = merge(data, InferenceData(; s.source => source_ds_new, k => subset_ds))
    return inferencedata(idata_merged, others...; kwargs...)
end
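The core of the Subset method above can be sketched without InferenceObjects by using plain NamedTuples in place of Datasets. normalize_var_map and split_subset are hypothetical helpers written for illustration only; they mirror the prototype's conventions (a bare name v means v => v, and pairs are new_name => source_name) but ignore dimensions and metadata entirely.

```julia
# Hypothetical helper: a bare name `v` is shorthand for `v => v`.
# Parentheses are needed because `=>` binds more loosely than `?:`.
normalize_var_map(var_map::Tuple) = map(v -> v isa Pair ? v : (v => v), var_map)

# Split `source` into (variables left in the source group, moved-and-renamed subset).
# `var_map` entries are `new_name => source_name` pairs or bare names.
function split_subset(source::NamedTuple, var_map::Tuple)
    var_pairs = normalize_var_map(var_map)
    new_names = map(first, var_pairs)
    source_names = map(last, var_pairs)
    # variables that are not moved stay in the source group
    keep = Tuple(k for k in keys(source) if k ∉ source_names)
    remaining = NamedTuple{keep}(map(k -> source[k], keep))
    # moved variables, looked up by source name and stored under the new name
    moved = NamedTuple{new_names}(map(k -> source[k], source_names))
    return remaining, moved
end

post = (x = [1.0, 2.0], lp = [0.1, 0.2], y_hat = [3.0, 4.0])
remaining, moved = split_subset(post, (:y => :y_hat, :lp))
# remaining == (x = [1.0, 2.0],)
# moved == (y = [3.0, 4.0], lp = [0.1, 0.2])
```

Nothing in split_subset depends on where the source data came from, which is the sense in which the subsetting step can be written once rather than per storage type.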

Demonstration

Now here's a demonstration of how we use it:

julia> ndraws, nchains = 1_000, 4;

julia> data_all = (
           x = randn(4, ndraws, nchains),
           z = randn(2, ndraws, nchains),
           lp = randn(ndraws, nchains),
           log_like = randn(10, ndraws, nchains),
           y_hat = randn(10, ndraws, nchains),
       );

julia> idata = inferencedata(
           data_all,
           :posterior_predictive => subset(:posterior, (:y => :y_hat,)),
           :log_likelihood => subset(:posterior, (:y => :log_like,)), 
           :sample_stats => subset(:posterior, (:lp,)),
       )
InferenceData with groups:
  > posterior
  > posterior_predictive
  > log_likelihood
  > sample_stats

julia> idata.posterior
Dataset with dimensions: 
  Dim{:x_dim_1} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:z_dim_1} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points
and 2 layers:
  :x Float64 dims: Dim{:x_dim_1}, Dim{:draw}, Dim{:chain} (4×1000×4)
  :z Float64 dims: Dim{:z_dim_1}, Dim{:draw}, Dim{:chain} (2×1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-11-03T23:09:29.868"

julia> idata.posterior_predictive
Dataset with dimensions: 
  Dim{:y_hat_dim_1} Sampled{Int64} Base.OneTo(10) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
  :y Float64 dims: Dim{:y_hat_dim_1}, Dim{:draw}, Dim{:chain} (10×1000×4)


julia> idata.log_likelihood
Dataset with dimensions: 
  Dim{:log_like_dim_1} Sampled{Int64} Base.OneTo(10) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
  :y Float64 dims: Dim{:log_like_dim_1}, Dim{:draw}, Dim{:chain} (10×1000×4)


julia> idata.sample_stats
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
  :lp Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)

The subsetting machinery is generic, so we don't need to customize it for every type as we currently do in the from_XXX methods.

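There are still some kinks to work out in this pipeline, such as correct handling of dimensions when variables are renamed (above, the renamed variable y keeps the dimension name y_hat_dim_1 in posterior_predictive and log_like_dim_1 in log_likelihood), but let's first check the extensibility of the pipeline.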
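One possible way to address the dimension-naming kink, assuming default dimension names always follow the <variable>_dim_<i> pattern seen in the output above, is to rewrite those names whenever a variable is renamed. rename_default_dims is a hypothetical helper, not an existing InferenceObjects function, and it only handles default names; user-supplied dimension names would be left alone.

```julia
# Hypothetical helper: when variable `old_name` is renamed to `new_name`, rewrite
# default dimension names `<old_name>_dim_<i>` to `<new_name>_dim_<i>` and leave
# all other dimensions (e.g. :draw, :chain) untouched.
function rename_default_dims(dims::Tuple, old_name::Symbol, new_name::Symbol)
    prefix = string(old_name, "_dim_")
    return map(dims) do d
        s = string(d)
        startswith(s, prefix) || return d
        # everything after the prefix is the dimension index, e.g. "1"
        return Symbol(string(new_name, "_dim_", s[ncodeunits(prefix)+1:end]))
    end
end

rename_default_dims((:y_hat_dim_1, :draw, :chain), :y_hat, :y)  # (:y_dim_1, :draw, :chain)
```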

Demonstration of pipeline extensibility

Here we define two storage types for MCMC results, one for each of the two kinds of objects described above.

# represents some object containing data for a single group
struct PosteriorStorage
    nt
end
dataset(post::PosteriorStorage; kwargs...) = dataset(post.nt; kwargs...)

# represents some object containing data for multiple groups, here a posterior and sample_stats,
# e.g. MCMCChains.Chains or SampleChains.MultiChain.
# It can be converted either to an InferenceData or to a Dataset; in the latter case only
# a single Dataset, here the posterior, is extracted.
struct MultiGroupStorage
    nt
end
function inferencedata(post::MultiGroupStorage; kwargs...)
    # split lp out of the stored NamedTuple into its own sample_stats group
    return inferencedata(post.nt, :sample_stats => subset(:posterior, (:lp,)); kwargs...)
end
dataset(post::MultiGroupStorage; kwargs...) = inferencedata(post; kwargs...).posterior

Now let's wrap our NamedTuple in these types and execute the pipeline:

julia> idata2 = inferencedata(
           PosteriorStorage(data_all),
           :posterior_predictive => subset(:posterior, (:y => :y_hat,)),
           :log_likelihood => subset(:posterior, (:y => :log_like,)), 
           :sample_stats => subset(:posterior, (:lp,)),
       )
InferenceData with groups:
  > posterior
  > posterior_predictive
  > log_likelihood
  > sample_stats

julia> idata2.posterior
Dataset with dimensions: 
  Dim{:x_dim_1} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:z_dim_1} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points
and 2 layers:
  :x Float64 dims: Dim{:x_dim_1}, Dim{:draw}, Dim{:chain} (4×1000×4)
  :z Float64 dims: Dim{:z_dim_1}, Dim{:draw}, Dim{:chain} (2×1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-11-03T23:23:57.729"

julia> idata2.posterior_predictive
Dataset with dimensions: 
  Dim{:y_hat_dim_1} Sampled{Int64} Base.OneTo(10) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
  :y Float64 dims: Dim{:y_hat_dim_1}, Dim{:draw}, Dim{:chain} (10×1000×4)


julia> idata3 = inferencedata(
           MultiGroupStorage(data_all),
           :posterior_predictive => subset(:posterior, (:y => :y_hat,)),
           :log_likelihood => subset(:posterior, (:y => :log_like,)), 
       )
InferenceData with groups:
  > posterior
  > posterior_predictive
  > log_likelihood
  > sample_stats

julia> idata3.posterior
Dataset with dimensions: 
  Dim{:x_dim_1} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:z_dim_1} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points
and 2 layers:
  :x Float64 dims: Dim{:x_dim_1}, Dim{:draw}, Dim{:chain} (4×1000×4)
  :z Float64 dims: Dim{:z_dim_1}, Dim{:draw}, Dim{:chain} (2×1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-11-03T23:24:29.005"

julia> idata3.sample_stats
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
  :lp Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
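The extensibility demonstrated above rests only on multiple dispatch, which the following self-contained toy mirrors without InferenceObjects. All names here are hypothetical: a NamedTuple stands in for a Dataset, a Dict{Symbol,NamedTuple} stands in for an InferenceData, subset support is omitted, and keywords are forwarded but ignored.

```julia
# Toy stand-ins (assumptions, not InferenceObjects types): a NamedTuple plays
# the role of a Dataset, and a Dict of group name => NamedTuple plays the role
# of an InferenceData. Keywords are forwarded but ignored.
const ToyIData = Dict{Symbol,NamedTuple}

toy_dataset(nt::NamedTuple; kwargs...) = nt
toy_idata(idata::ToyIData; kwargs...) = idata
# any other object is treated as the posterior
toy_idata(data; kwargs...) = ToyIData(:posterior => toy_dataset(data; kwargs...))

# convert the leading object, then fold in each `group => data` pair
function toy_idata(data, next::Pair{Symbol}, others::Pair{Symbol}...; kwargs...)
    k, v = next
    merged = merge(toy_idata(data; kwargs...), ToyIData(k => toy_dataset(v; kwargs...)))
    return toy_idata(merged, others...; kwargs...)
end

# Hook type 1: data for a single group only needs a `toy_dataset` method.
struct ToyPosteriorStorage
    nt::NamedTuple
end
toy_dataset(s::ToyPosteriorStorage; kwargs...) = toy_dataset(s.nt; kwargs...)

# Hook type 2: data for several groups implements both hooks.
struct ToyMultiGroupStorage
    posterior::NamedTuple
    sample_stats::NamedTuple
end
toy_idata(s::ToyMultiGroupStorage; kwargs...) =
    toy_idata(s.posterior, :sample_stats => s.sample_stats; kwargs...)
toy_dataset(s::ToyMultiGroupStorage; kwargs...) = toy_idata(s; kwargs...)[:posterior]

a = toy_idata(ToyPosteriorStorage((x = [1.0, 2.0],)), :observed_data => (y = [0.5],))
b = toy_idata(ToyMultiGroupStorage((x = [1.0],), (lp = [-1.0],)), :observed_data => (y = [0.5],))
sort(collect(keys(a)))  # [:observed_data, :posterior]
sort(collect(keys(b)))  # [:observed_data, :posterior, :sample_stats]
```

Because the generic varargs method only ever calls toy_idata(data) and toy_dataset(v), a new storage type joins the toy pipeline by adding methods, without touching the generic code.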