diff --git a/Project.toml b/Project.toml
index 95342249c..97969944d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -46,7 +46,7 @@ DynamicPPLZygoteRulesExt = ["ZygoteRules"]
 [compat]
 ADTypes = "1"
 AbstractMCMC = "5"
-AbstractPPL = "0.8.4, 0.9"
+AbstractPPL = "0.10.1"
 Accessors = "0.1"
 BangBang = "0.4.1"
 Bijectors = "0.13.18, 0.14, 0.15"
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 52200171d..06cde3bac 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -42,6 +42,158 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
     return keys(c.info.varname_to_symbol)
 end
 
+"""
+    predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
+
+Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
+in `chain`, and return the resulting `Chains`.
+
+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then leverages the previously inferred parameter values to
+simulate what new, unobserved data might look like, given your posterior beliefs.
+
+For each parameter configuration in `chain`:
+1. All random variables present in `chain` are fixed to their sampled values.
+2. Any variables not included in `chain` are sampled from their prior distributions.
+
+If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
+predictive distribution.
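+With `include_all=true` the variables fixed from `chain` are instead kept alongside the new
+predictions. A minimal sketch of this (not a doctest), reusing `m_test` and `β_chain` from the
+example below:
+
+```julia
+# Hypothetical usage: the returned chain carries β as well as y[1] and y[2].
+predictions_all = DynamicPPL.predict(Random.default_rng(), m_test, β_chain; include_all=true)
+```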
+
+# Examples
+```jldoctest
+using AbstractMCMC, Distributions, DynamicPPL, MCMCChains, Random
+
+@model function linear_reg(x, y, σ = 0.1)
+    β ~ Normal(0, 1)
+    for i in eachindex(y)
+        y[i] ~ Normal(β * x[i], σ)
+    end
+end
+
+# Generate synthetic chain using known ground truth parameter
+ground_truth_β = 2.0
+
+# Create chain of samples from a normal distribution centered on ground truth
+β_chain = MCMCChains.Chains(
+    rand(Normal(ground_truth_β, 0.002), 1000), [:β]
+)
+
+# Generate predictions for two test points
+xs_test = [10.1, 10.2]
+
+m_test = linear_reg(xs_test, fill(missing, length(xs_test)))
+
+predictions = DynamicPPL.predict(
+    Random.default_rng(), m_test, β_chain
+)
+
+ys_pred = vec(mean(Array(predictions); dims=1))
+
+# Check if predictions match expected values within tolerance
+(
+    isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.01),
+    isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.01)
+)
+
+# output
+
+(true, true)
+```
+"""
+function DynamicPPL.predict(
+    rng::DynamicPPL.Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    chain::MCMCChains.Chains;
+    include_all=false,
+)
+    parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
+    varinfo = DynamicPPL.VarInfo(model)
+
+    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
+    predictive_samples = map(iters) do (sample_idx, chain_idx)
+        # Fix the variables present in `chain`; anything else is resampled from the prior.
+        DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
+        model(rng, varinfo, DynamicPPL.SampleFromPrior())
+
+        vals = DynamicPPL.values_as_in_model(model, varinfo)
+        varname_vals = mapreduce(
+            collect,
+            vcat,
+            map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
+        )
+
+        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
+    end
+
+    chain_result = reduce(
+        MCMCChains.chainscat,
+        [
+            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
+            chain_idx in 1:size(predictive_samples, 2)
+        ],
+    )
+    parameter_names = if include_all
+        MCMCChains.names(chain_result, :parameters)
+    else
+        filter(
+            k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
+            MCMCChains.names(chain_result, :parameters),
+        )
+    end
+    return chain_result[parameter_names]
+end
+
+function _predictive_samples_to_arrays(predictive_samples)
+    variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
+
+    # One OrderedDict per sample; track the union of variable names across all samples.
+    sample_dicts = map(predictive_samples) do sample
+        varname_value_pairs = sample.varname_and_values
+        varnames = map(first, varname_value_pairs)
+        values = map(last, varname_value_pairs)
+        for varname in varnames
+            push!(variable_names_set, varname)
+        end
+
+        return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
+    end
+
+    variable_names = collect(variable_names_set)
+    variable_values = [
+        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
+        key in variable_names
+    ]
+
+    return variable_names, variable_values
+end
+
+function _predictive_samples_to_chains(predictive_samples)
+    variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
+    variable_names_symbols = map(Symbol, variable_names)
+
+    internal_parameters = [:lp]
+    log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)
+
+    parameter_names = [variable_names_symbols; internal_parameters]
+    parameter_values = hcat(variable_values, log_probabilities)
+    parameter_values = MCMCChains.concretize(parameter_values)
+
+    return MCMCChains.Chains(
+        parameter_values, parameter_names, (internals=internal_parameters,)
+    )
+end
+
 """
     returned(model::Model, chain::MCMCChains.Chains)
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index 21fdcebc9..c1cdbd94e 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -5,7 +5,7 @@ using AbstractPPL
 using Bijectors
 using Compat
 using Distributions
-using OrderedCollections: OrderedDict
+using OrderedCollections: OrderedCollections, OrderedDict
 
 using AbstractMCMC: AbstractMCMC
 using ADTypes: ADTypes
@@ -40,6 +40,8 @@ import Base:
     keys,
     haskey
 
+import AbstractPPL: predict
+
 # VarInfo
 export AbstractVarInfo,
     VarInfo,
diff --git a/src/model.jl b/src/model.jl
index 0214d5feb..2bad6f1fe 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1144,6 +1144,26 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
     end
 end
 
+"""
+    predict([rng::AbstractRNG,] model::Model, chain::AbstractArray{<:AbstractVarInfo})
+
+Generate samples from the posterior predictive distribution by evaluating `model` at each set
+of parameter values provided in `chain`. The number of posterior predictive samples matches
+the length of `chain`. The returned `AbstractVarInfo`s contain both the posterior parameter values
+and the predicted values.
+"""
+function predict(
+    rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
+)
+    varinfo = DynamicPPL.VarInfo(model)
+    return map(chain) do params_varinfo
+        vi = deepcopy(varinfo)
+        DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
+        model(rng, vi, SampleFromPrior())
+        return vi
+    end
+end
+
 """
     returned(model::Model, parameters::NamedTuple)
     returned(model::Model, values, keys)
diff --git a/test/Project.toml b/test/Project.toml
index 8d9b19bce..c7583c672 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -33,7 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 ADTypes = "1"
 AbstractMCMC = "5"
-AbstractPPL = "0.8.4, 0.9"
+AbstractPPL = "0.10.1"
 Accessors = "0.1"
 Bijectors = "0.15.1"
 Combinatorics = "1"
diff --git a/test/contexts.jl b/test/contexts.jl
index 0f6628440..7a7826466 100644
--- a/test/contexts.jl
+++ b/test/contexts.jl
@@ -202,7 +202,7 @@ end
         s, m = retval.s, retval.m
 
         # Keyword approach.
-        model_fixed = fix(model; s=s)
+        model_fixed = DynamicPPL.fix(model; s=s)
         @test model_fixed().s == s
         @test model_fixed().m != m
         # A fixed variable should not contribute at all to the logjoint.
@@ -210,19 +210,19 @@
         @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
 
         # Positional approach.
-        model_fixed = fix(model, (; s))
+        model_fixed = DynamicPPL.fix(model, (; s))
         @test model_fixed().s == s
         @test model_fixed().m != m
         @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
 
         # Pairs approach.
-        model_fixed = fix(model, @varname(s) => s)
+        model_fixed = DynamicPPL.fix(model, @varname(s) => s)
         @test model_fixed().s == s
         @test model_fixed().m != m
         @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
 
         # Dictionary approach.
-        model_fixed = fix(model, Dict(@varname(s) => s))
+        model_fixed = DynamicPPL.fix(model, Dict(@varname(s) => s))
         @test model_fixed().s == s
         @test model_fixed().m != m
         @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl
index e117c4bbc..3ba5edfe1 100644
--- a/test/ext/DynamicPPLMCMCChainsExt.jl
+++ b/test/ext/DynamicPPLMCMCChainsExt.jl
@@ -7,3 +7,5 @@
     @test size(chain_generated) == (1000, 1)
     @test mean(chain_generated) ≈ 0 atol = 0.1
 end
+
+# Tests for `predict` are in `test/model.jl`.
diff --git a/test/model.jl b/test/model.jl
index a19cb29d2..cb1dbc735 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -429,4 +429,109 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
             @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result)
         end
     end
+
+    @testset "predict" begin
+        @testset "with MCMCChains.Chains" begin
+            DynamicPPL.Random.seed!(100)
+
+            @model function linear_reg(x, y, σ=0.1)
+                β ~ Normal(0, 1)
+                for i in eachindex(y)
+                    y[i] ~ Normal(β * x[i], σ)
+                end
+            end
+
+            @model function linear_reg_vec(x, y, σ=0.1)
+                β ~ Normal(0, 1)
+                return y ~ MvNormal(β .* x, σ^2 * I)
+            end
+
+            ground_truth_β = 2
+            β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β])
+
+            xs_test = [10 + 0.1, 10 + 2 * 0.1]
+            m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
+            predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)
+
+            ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
+            @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
+            @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
+
+            # Ensure that `rng` is respected
+            rng = MersenneTwister(42)
+            predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
+            predictions2 = DynamicPPL.predict(
+                MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
+            )
+            @test all(Array(predictions1) .== Array(predictions2))
+
+            # Repeat the check with the vectorized model
+            m_lin_reg_test = linear_reg_vec(xs_test, missing)
+            predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
+            ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
+
+            @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
+            @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
+
+            # Multiple chains
+            multiple_β_chain = MCMCChains.Chains(
+                reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β]
+            )
+            m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
+            predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
+            @test size(multiple_β_chain, 3) == size(predictions, 3)
+
+            for chain_idx in MCMCChains.chains(multiple_β_chain)
+                ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
+                @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
+                @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
+            end
+
+            # Repeat the multiple-chain check with the vectorized model
+            m_lin_reg_test = linear_reg_vec(xs_test, missing)
+            predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
+
+            for chain_idx in MCMCChains.chains(multiple_β_chain)
+                ys_pred_vec = vec(
+                    mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
+                )
+                @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
+                @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
+            end
+        end
+
+        @testset "with AbstractVector{<:AbstractVarInfo}" begin
+            @model function linear_reg(x, y, σ=0.1)
+                β ~ Normal(1, 1)
+                for i in eachindex(y)
+                    y[i] ~ Normal(β * x[i], σ)
+                end
+            end
+
+            ground_truth_β = 2.0
+            # the data is ignored here, since the parameter samples are drawn from the prior
+            xs_train = 1:0.1:10
+            ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
+            m_lin_reg = linear_reg(xs_train, ys_train)
+            chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000]
+
+            # chain is generated from the prior
+            @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1
+
+            xs_test = [10 + 0.1, 10 + 2 * 0.1]
+            m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
+            predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain)
+
+            @test size(predicted_vis) == size(chain)
+            @test Set(keys(predicted_vis[1])) ==
+                Set([@varname(β), @varname(y[1]), @varname(y[2])])
+            # β is drawn from the prior (mean 1), so predictions center on 1.0 * x with larger spread
+            @test mean([
+                predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis)
+            ]) ≈ 1.0 * xs_test[1] rtol = 0.1
+            @test mean([
+                predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis)
+            ]) ≈ 1.0 * xs_test[2] rtol = 0.1
+        end
+    end
 end
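
For illustration, a minimal sketch of the new `AbstractVarInfo`-based `predict` (not part of the
diff; `linear_reg` as in the test above, with hypothetical data):

```julia
using DynamicPPL, Distributions, Random

@model function linear_reg(x, y, σ=0.1)
    β ~ Normal(1, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end

# Parameter configurations drawn from the prior (stand-ins for posterior samples).
m = linear_reg([1.0, 2.0], [2.1, 3.9])
chain = [DynamicPPL.evaluate!!(m)[2] for _ in 1:100]

# Predict the missing response at a new input; each returned VarInfo holds β and y[1].
m_test = linear_reg([3.0], [missing])
predicted_vis = DynamicPPL.predict(Random.default_rng(), m_test, chain)
first_draw = predicted_vis[1][@varname(y[1])]
```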