Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# DynamicPPL Changelog

## 0.38.3

Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`.
Please note we generally recommend using Dict, as NamedTuples cannot correctly represent variables with indices / fields on the left-hand side of tildes, like `x[1]` or `x.a`.

The generic method `returned(::Model, values, keys)` is deprecated and will be removed in the next minor version.

## 0.38.2

Added a compatibility entry for [email protected].
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.38.2"
version = "0.38.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ It is possible to manually increase (or decrease) the accumulated log likelihood
@addlogprob!
```

Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`.
Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples), or a single sample represented as a `NamedTuple` or a dictionary of VarNames.

```@docs
returned(::DynamicPPL.Model, ::MCMCChains.Chains)
returned(::DynamicPPL.Model, ::NamedTuple)
returned(::DynamicPPL.Model, ::Union{NamedTuple,AbstractDict{<:VarName}})
```

For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using
Expand Down
50 changes: 25 additions & 25 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1103,44 +1103,44 @@ function predict end

"""
returned(model::Model, parameters::NamedTuple)
returned(model::Model, values, keys)
returned(model::Model, values, keys)
returned(model::Model, parameters::AbstractDict{<:VarName})

Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.

If a `NamedTuple` is given, `keys=keys(parameters)` and `values=values(parameters)`.
returned(model::Model, values, keys)

Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.
This method is deprecated; use the NamedTuple or AbstractDict version instead.

# Example
```jldoctest
julia> using DynamicPPL, Distributions

julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
m_shifted ~ Normal(10, √s)
m = m_shifted - 10
for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end
return (m, )
julia> @model function demo()
m ~ Normal()
return (mp1 = m + 1,)
end
demo (generic function with 2 methods)

julia> model = demo(randn(10));

julia> parameters = (; s = 1.0, m_shifted=10.0);
julia> model = demo();

julia> returned(model, parameters)
(0.0,)
julia> returned(model, (; m = 1.0))
(mp1 = 2.0,)

julia> returned(model, values(parameters), keys(parameters))
(0.0,)
julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0))
(mp1 = 3.0,)
```
"""
function returned(model::Model, parameters::NamedTuple)
fixed_model = fix(model, parameters)
return fixed_model()
end

function returned(model::Model, values, keys)
return returned(model, NamedTuple{keys}(values))
function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}})
vi = DynamicPPL.setaccs!!(VarInfo(), ())
# Note: we can't use `fix(model, parameters)` because
# https://github.com/TuringLang/DynamicPPL.jl/issues/1097
# Use `nothing` as the fallback to ensure that any missing parameters cause an error
ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing))
new_model = setleafcontext(model, ctx)
# We can't use new_model() because that overwrites it with an InitContext of its own.
return first(evaluate!!(new_model, vi))
end
Base.@deprecate returned(model::Model, values, keys) returned(
model, NamedTuple{keys}(values)
)
32 changes: 32 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,38 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end

@testset "returned() on NamedTuple / Dict" begin
@model function demo_returned()
a ~ Normal()
b ~ Normal()
return (asq=a^2, bsq=b^2)
end
model = demo_returned()

@testset "NamedTuple" begin
params = (a=1.0, b=2.0)
results = returned(model, params)
@test results.asq == params.a^2
@test results.bsq == params.b^2
# `returned` should error when not all parameters are provided
@test_throws ErrorException returned(model, (; a=1.0))
@test_throws ErrorException returned(model, (a=1.0, b=missing))
end
@testset "Dict" begin
params = Dict{VarName,Float64}(@varname(a) => 1.0, @varname(b) => 2.0)
results = returned(model, params)
@test results.asq == params[@varname(a)]^2
@test results.bsq == params[@varname(b)]^2
# `returned` should error when not all parameters are provided
@test_throws ErrorException returned(
model, Dict{VarName,Float64}(@varname(a) => 1.0)
)
@test_throws ErrorException returned(
model, Dict{VarName,Any}(@varname(a) => 1.0, @varname(b) => missing)
)
end
end

@testset "returned() on `LKJCholesky`" begin
n = 10
d = 2
Expand Down