From 4449032614b954640edaf217065c6397be9a226e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 28 Oct 2025 00:15:28 +0000 Subject: [PATCH 1/4] Implement returned for AbstractDict; deprecate {values, keys} method --- HISTORY.md | 9 +++++++++ Project.toml | 2 +- docs/src/api.md | 4 ++-- src/model.jl | 47 ++++++++++++++++++++++------------------------- 4 files changed, 34 insertions(+), 28 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 604dcb725..0eb2f6edd 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,14 @@ # DynamicPPL Changelog +## 0.38.3 + +Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`. +Also tweaked the implementation of `returned(::Model, ::NamedTuple)` to accumulate log-probabilities correctly. + +Please note we generally recommend using Dict though, as NamedTuples cannot correctly represent variables with indices / fields on the left-hand side of tildes, like `x[1]` or `x.a`. + +The generic method `returned(::Model, values, keys)` is deprecated and will be removed in the next minor version. + ## 0.38.2 Added a compatibility entry for JET@0.11. diff --git a/Project.toml b/Project.toml index 83e0fea3f..d54f9d1da 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.2" +version = "0.38.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api.md b/docs/src/api.md index 80970c0bb..31b7d07da 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -176,11 +176,11 @@ It is possible to manually increase (or decrease) the accumulated log likelihood @addlogprob! ``` -Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`. +Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples), or a single sample represented as a `NamedTuple` or a dictionary of VarNames. ```@docs returned(::DynamicPPL.Model, ::MCMCChains.Chains) -returned(::DynamicPPL.Model, ::NamedTuple) +returned(::DynamicPPL.Model, ::Union{NamedTuple,AbstractDict{<:VarName}}) ``` For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using diff --git a/src/model.jl b/src/model.jl index d6682416b..6bf784eca 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1103,44 +1103,41 @@ function predict end """ returned(model::Model, parameters::NamedTuple) - returned(model::Model, values, keys) - returned(model::Model, values, keys) + returned(model::Model, parameters::AbstractDict{<:VarName}) Execute `model` with variables `keys` set to `values` and return the values returned by the `model`. -If a `NamedTuple` is given, `keys=keys(parameters)` and `values=values(parameters)`. + returned(model::Model, values, keys) + +Execute `model` with variables `keys` set to `values` and return the values returned by the `model`. +This method is deprecated; use the NamedTuple or AbstractDict version instead. # Example ```jldoctest julia> using DynamicPPL, Distributions -julia> @model function demo(xs) - s ~ InverseGamma(2, 3) - m_shifted ~ Normal(10, √s) - m = m_shifted - 10 - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - return (m, ) +julia> @model function demo() + m ~ Normal() + return (mp1 = m + 1,) end demo (generic function with 2 methods) -julia> model = demo(randn(10)); - -julia> parameters = (; s = 1.0, m_shifted=10.0); +julia> model = demo(); -julia> returned(model, parameters) -(0.0,) +julia> returned(model, (; m = 1.0)) +(mp1 = 2.0,) -julia> returned(model, values(parameters), keys(parameters)) -(0.0,) +julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0)) +(3.0,) ``` """ -function returned(model::Model, parameters::NamedTuple) - fixed_model = fix(model, parameters) - return fixed_model() -end - -function returned(model::Model, values, keys) - return returned(model, NamedTuple{keys}(values)) +function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}}) + # use `nothing` as the fallback to ensure that any missing parameters cause an error + ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing)) + new_model = setleafcontext(model, ctx) + # We can't use new_model() because that overwrites it with an InitContext of its own. + return first(evaluate!!(new_model, VarInfo())) end +Base.@deprecate returned(model::Model, values, keys) returned( + model, NamedTuple{keys}(values) +) From d448e2783b5fa66828e5137ad587a30555a988a3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 28 Oct 2025 00:24:56 +0000 Subject: [PATCH 2/4] Fix doctest --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 6bf784eca..c4d937007 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1128,7 +1128,7 @@ julia> returned(model, (; m = 1.0)) (mp1 = 2.0,) julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0)) -(3.0,) +(mp1 = 3.0,) ``` """ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}}) From fa472e4f5f0a9ba390d0c55f4f949310893a3bcd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 28 Oct 2025 00:46:25 +0000 Subject: [PATCH 3/4] Add more tests (beyond the doctest) --- test/model.jl | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/model.jl b/test/model.jl index 6ba3bca2a..6da5ea246 100644 --- a/test/model.jl +++ b/test/model.jl @@ -321,6 +321,38 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end + @testset "returned() on NamedTuple / Dict" begin + @model function demo_returned() + a ~ Normal() + b ~ Normal() + return (asq=a^2, bsq=b^2) + end + model = demo_returned() + + @testset "NamedTuple" begin + params = (a=1.0, b=2.0) + results = returned(model, params) + @test results.asq == params.a^2 + @test results.bsq == params.b^2 + # `returned` should error when not all parameters are provided + @test_throws ErrorException returned(model, (; a=1.0)) + @test_throws ErrorException returned(model, (a=1.0, b=missing)) + end + @testset "Dict" begin + params = Dict{VarName,Float64}(@varname(a) => 1.0, @varname(b) => 2.0) + results = returned(model, params) + @test results.asq == params[@varname(a)]^2 + @test results.bsq == params[@varname(b)]^2 + # `returned` should error when not all parameters are provided + @test_throws ErrorException returned( + model, Dict{VarName,Float64}(@varname(a) => 1.0) + ) + @test_throws ErrorException returned( + model, Dict{VarName,Any}(@varname(a) => 1.0, @varname(b) => missing) + ) + end + end + @testset "returned() on `LKJCholesky`" begin n = 10 d = 2 From fac86d2353ceda15fdda2f64899de79a09b71332 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 28 Oct 2025 16:12:24 +0000 Subject: [PATCH 4/4] Remove accs --- HISTORY.md | 4 +--- src/model.jl | 7 +++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 0eb2f6edd..54b40b7e9 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -3,9 +3,7 @@ ## 0.38.3 Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`. -Also tweaked the implementation of `returned(::Model, ::NamedTuple)` to accumulate log-probabilities correctly. - -Please note we generally recommend using Dict though, as NamedTuples cannot correctly represent variables with indices / fields on the left-hand side of tildes, like `x[1]` or `x.a`. +Please note we generally recommend using Dict, as NamedTuples cannot correctly represent variables with indices / fields on the left-hand side of tildes, like `x[1]` or `x.a`. The generic method `returned(::Model, values, keys)` is deprecated and will be removed in the next minor version. diff --git a/src/model.jl b/src/model.jl index c4d937007..edb042ba9 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1132,11 +1132,14 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0)) ``` """ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}}) - # use `nothing` as the fallback to ensure that any missing parameters cause an error + vi = DynamicPPL.setaccs!!(VarInfo(), ()) + # Note: we can't use `fix(model, parameters)` because + # https://github.com/TuringLang/DynamicPPL.jl/issues/1097 + # Use `nothing` as the fallback to ensure that any missing parameters cause an error ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing)) new_model = setleafcontext(model, ctx) # We can't use new_model() because that overwrites it with an InitContext of its own. - return first(evaluate!!(new_model, VarInfo())) + return first(evaluate!!(new_model, vi)) end Base.@deprecate returned(model::Model, values, keys) returned( model, NamedTuple{keys}(values)