From 84a5aa8666e6f8bac25414c6eeabaf57a546a3f2 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Nov 2022 09:47:57 +0100 Subject: [PATCH] Add from_dict (#35) * Add from_dict * Fix from_namedtuple docstring * Generalize function signature * Add from_dict tests * Test also with OrderedDict * Test as_namedtuple * Increment patch number * Add from_dict to docs * Correctly import package --- Project.toml | 6 ++-- docs/src/inference_data.md | 1 + src/InferenceObjects.jl | 4 ++- src/from_dict.jl | 72 ++++++++++++++++++++++++++++++++++++++ src/from_namedtuple.jl | 2 +- src/utils.jl | 6 ++++ test/from_dict.jl | 70 ++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + test/utils.jl | 8 ++++- 9 files changed, 165 insertions(+), 5 deletions(-) create mode 100644 src/from_dict.jl create mode 100644 test/from_dict.jl diff --git a/Project.toml b/Project.toml index 49160eb0..ceb2dbde 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "InferenceObjects" uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00" authors = ["Seth Axen and contributors"] -version = "0.2.4" +version = "0.2.5" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -12,11 +12,13 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Compat = "3.46.0, 4.2.0" DimensionalData = "0.20, 0.21, 0.22, 0.23" OffsetArrays = "1" +OrderedCollections = "1" julia = "1.6" [extras] OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "OffsetArrays"] +test = ["Test", "OffsetArrays", "OrderedCollections"] diff --git a/docs/src/inference_data.md b/docs/src/inference_data.md index 2e92f204..5c3ba1af 100644 --- a/docs/src/inference_data.md +++ b/docs/src/inference_data.md @@ -33,6 +33,7 @@ That is, iterating over an `InferenceData` iterates over its groups. ```@docs convert_to_inference_data +from_dict from_namedtuple ``` diff --git a/src/InferenceObjects.jl b/src/InferenceObjects.jl index 4830bf58..2e984a18 100644 --- a/src/InferenceObjects.jl +++ b/src/InferenceObjects.jl @@ -27,7 +27,8 @@ const SCHEMA_GROUPS_DICT = Dict(n => i for (i, n) in enumerate(SCHEMA_GROUPS)) const DEFAULT_SAMPLE_DIMS = Dimensions.key2dim((:draw, :chain)) export Dataset, InferenceData -export convert_to_dataset, convert_to_inference_data, from_namedtuple, namedtuple_to_dataset +export convert_to_dataset, + convert_to_inference_data, from_dict, from_namedtuple, namedtuple_to_dataset include("utils.jl") include("dimensions.jl") @@ -36,5 +37,6 @@ include("inference_data.jl") include("convert_dataset.jl") include("convert_inference_data.jl") include("from_namedtuple.jl") +include("from_dict.jl") end # module diff --git a/src/from_dict.jl b/src/from_dict.jl new file mode 100644 index 00000000..92b968a7 --- /dev/null +++ b/src/from_dict.jl @@ -0,0 +1,72 @@ +""" + from_dict(posterior::AbstractDict; kwargs...) -> InferenceData + +Convert a `Dict` to an `InferenceData`. + +# Arguments + + - `posterior`: The data to be converted. Its strings must be `Symbol` or `AbstractString`, + and its values must be arrays. + +# Keywords + + - `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution + - `sample_stats::Any=nothing`: Statistics of the posterior sampling process + - `predictions::Any=nothing`: Out-of-sample predictions for the posterior. + - `prior::Dict=nothing`: Draws from the prior + - `prior_predictive::Any=nothing`: Draws from the prior predictive distribution + - `sample_stats_prior::Any=nothing`: Statistics of the prior sampling process + - `observed_data::NamedTuple`: Observed data on which the `posterior` is + conditional. It should only contain data which is modeled as a random variable. Keys + are parameter names and values. + - `constant_data::NamedTuple`: Model constants, data included in the model + which is not modeled as a random variable. Keys are parameter names and values. + - `predictions_constant_data::NamedTuple`: Constants relevant to the model + predictions (i.e. new `x` values in a linear regression). + - `log_likelihood`: Pointwise log-likelihood for the data. It is recommended + to use this argument as a `NamedTuple` whose keys are observed variable names and whose + values are log likelihood arrays. + - `library`: Name of library that generated the draws + - `coords`: Map from named dimension to named indices + - `dims`: Map from variable name to names of its dimensions + +# Returns + + - `InferenceData`: The data with groups corresponding to the provided data + +# Examples + +```@example +using InferenceObjects +nchains = 2 +ndraws = 100 + +data = Dict( + :x => rand(ndraws, nchains), + :y => randn(2, ndraws, nchains), + :z => randn(3, 2, ndraws, nchains), +) +idata = from_dict(data) +``` +""" +from_dict + +function from_dict( + posterior::Union{<:AbstractDict,Nothing}=nothing; prior=nothing, kwargs... +) + nt = posterior === nothing ? posterior : as_namedtuple(posterior) + nt_prior = prior === nothing ? prior : as_namedtuple(prior) + return from_namedtuple(nt; prior=nt_prior, kwargs...) +end + +""" + convert_to_inference_data(obj::AbstractDict; kwargs...) -> InferenceData + +Convert `obj` to an [`InferenceData`](@ref). See [`from_namedtuple`](@ref) for a description +of `obj` possibilities and `kwargs`. +""" +function convert_to_inference_data(data::AbstractDict; group=:posterior, kwargs...) + group = Symbol(group) + group === :posterior && return from_dict(data; kwargs...) + return from_dict(; group => data, kwargs...) +end diff --git a/src/from_namedtuple.jl b/src/from_namedtuple.jl index 70a5c03b..102dcfbb 100644 --- a/src/from_namedtuple.jl +++ b/src/from_namedtuple.jl @@ -33,7 +33,7 @@ whose first dimensions correspond to the dimensions of the containers. - `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution - `sample_stats::Any=nothing`: Statistics of the posterior sampling process - `predictions::Any=nothing`: Out-of-sample predictions for the posterior. - - `prior::Any=nothing`: Draws from the prior + - `prior=nothing`: Draws from the prior. Accepts the same types as `posterior`. - `prior_predictive::Any=nothing`: Draws from the prior predictive distribution - `sample_stats_prior::Any=nothing`: Statistics of the prior sampling process - `observed_data::NamedTuple`: Observed data on which the `posterior` is diff --git a/src/utils.jl b/src/utils.jl index 2270ff12..5f89d1b1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -67,3 +67,9 @@ function rekey(d::NamedTuple, keymap) new_keys = map(k -> get(keymap, k, k), keys(d)) return NamedTuple{new_keys}(values(d)) end + +as_namedtuple(dict::AbstractDict{Symbol}) = NamedTuple(dict) +function as_namedtuple(dict::AbstractDict{<:AbstractString}) + return NamedTuple(Symbol(k) => v for (k, v) in dict) +end +as_namedtuple(nt::NamedTuple) = nt diff --git a/test/from_dict.jl b/test/from_dict.jl new file mode 100644 index 00000000..13e82519 --- /dev/null +++ b/test/from_dict.jl @@ -0,0 +1,70 @@ +using InferenceObjects, OrderedCollections, Test + +@testset "from_dict" begin + nchains, ndraws = 4, 10 + sizes = (x=(), y=(2,), z=(3, 5)) + dims = (y=[:yx], z=[:zx, :zy]) + coords = (yx=["y1", "y2"], zx=1:3, zy=1:5) + + dicts = [ + "Dict{Symbol}" => + Dict(Symbol(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)), + "OrderedDict{String}" => + Dict(string(k) => randn(sz..., ndraws, nchains) for (k, sz) in pairs(sizes)), + ] + + @testset "posterior::$(type)" for (type, dict) in dicts + @test_broken @inferred from_dict(dict; dims, coords, library="MyLib") + idata1 = from_dict(dict; dims, coords, library="MyLib") + idata2 = convert_to_inference_data(dict; dims, coords, library="MyLib") + test_idata_approx_equal(idata1, idata2) + end + + @testset "$(group)" for group in [ + :posterior_predictive, :sample_stats, :predictions, :log_likelihood + ] + library = "MyLib" + @testset "::$(type)" for (type, dict) in dicts + idata1 = from_dict(dict; group => dict, dims, coords, library) + test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords) + + idata2 = from_dict(dict; group => (:x,), dims, coords, library) + test_idata_group_correct(idata2, :posterior, (:y, :z); library, dims, coords) + test_idata_group_correct(idata2, group, (:x,); library, dims, coords) + end + end + + @testset "$(group)" for group in [:prior_predictive, :sample_stats_prior] + library = "MyLib" + @testset "::$(type)" for (type, dict) in dicts + idata1 = from_dict(; prior=dict, group => dict, dims, coords, library) + test_idata_group_correct(idata1, :prior, keys(sizes); library, dims, coords) + test_idata_group_correct(idata1, group, keys(sizes); library, dims, coords) + + idata2 = from_dict(; prior=dict, group => (:x,), dims, coords, library) + test_idata_group_correct(idata2, :prior, (:y, :z); library, dims, coords) + test_idata_group_correct(idata2, group, (:x,); library, dims, coords) + end + end + + @testset "$(group)" for group in + [:observed_data, :constant_data, :predictions_constant_data] + _, dict = dicts[1] + library = "MyLib" + dims = (; w=[:wx]) + coords = (; wx=1:2) + idata1 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library) + test_idata_group_correct(idata1, :posterior, keys(sizes); library, dims, coords) + test_idata_group_correct( + idata1, group, (:w,); library, dims, coords, default_dims=() + ) + + # ensure that dims are matched to named tuple keys + # https://github.com/arviz-devs/ArviZ.jl/issues/96 + idata2 = from_dict(dict; group => Dict(:w => [1.0, 2.0]), dims, coords, library) + test_idata_group_correct(idata2, :posterior, keys(sizes); library, dims, coords) + test_idata_group_correct( + idata2, group, (:w,); library, dims, coords, default_dims=() + ) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 3004af19..55a52053 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,4 +9,5 @@ using InferenceObjects, Test include("convert_dataset.jl") include("convert_inference_data.jl") include("from_namedtuple.jl") + include("from_dict.jl") end diff --git a/test/utils.jl b/test/utils.jl index a1c0b4f5..a15bd2a0 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,4 @@ -using InferenceObjects, Test +using InferenceObjects, OrderedCollections, Test module TestSubModule end @@ -40,4 +40,10 @@ module TestSubModule end @test new == Dict(:y => 3, :a => 4, :z => 5) end end + + @testset "as_namedtuple" begin + @test InferenceObjects.as_namedtuple(OrderedDict(:x => 3, :y => 4)) === (x=3, y=4) + @test InferenceObjects.as_namedtuple(OrderedDict("x" => 4, "y" => 5)) === (x=4, y=5) + @test InferenceObjects.as_namedtuple((y=6, x=7)) === (y=6, x=7) + end end